object_storage_proxy/utils/
validator.rs1use pyo3::{PyObject, Python};
2use tokio::{sync::Mutex, task};
3use tracing::{debug, error};
4
5use std::{
6 collections::HashMap,
7 sync::{Arc, RwLock},
8 time::{Duration, Instant},
9};
10
11use crate::utils::functions::callable_accepts_request;
12
/// A single cached authorization decision.
///
/// The decision is served from the cache only while `Instant::now()` is strictly
/// earlier than `expires_at`.
#[derive(Debug, Clone)]
struct AuthEntry {
    /// Outcome produced by the validator (or supplied directly via `AuthCache::insert`).
    authorized: bool,
    /// Moment from which this entry is treated as stale.
    expires_at: Instant,
}
18
19#[derive(Clone, Debug)]
28pub struct AuthCache {
29 inner: Arc<RwLock<HashMap<String, AuthEntry>>>,
30 locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
31}
32
33impl Default for AuthCache {
34 fn default() -> Self {
35 Self::new()
36 }
37}
38
39impl AuthCache {
40 pub fn new() -> Self {
41 AuthCache {
42 inner: Arc::new(RwLock::new(HashMap::new())),
43 locks: Arc::new(Mutex::new(HashMap::new())),
44 }
45 }
46
47 pub async fn get_or_validate<F, Fut, E>(
48 &self,
49 key: &str,
50 ttl: Duration,
51 validator_fn: F,
52 ) -> Result<bool, E>
53 where
54 F: Fn() -> Fut + Send + Sync + 'static,
55 Fut: std::future::Future<Output = Result<bool, E>> + Send,
56 E: std::fmt::Debug,
57 {
58 if let Some(entry) = {
59 let map = self.inner.read().unwrap();
60 map.get(key).cloned()
61 } && Instant::now() < entry.expires_at
62 {
63 debug!("Cache hit for key.");
64 return Ok(entry.authorized);
65 }
66 debug!("Cache miss for key. Validating authorization...");
67 let key_lock = {
68 let mut locks_map = self.locks.lock().await;
69 locks_map
70 .entry(key.to_string())
71 .or_insert_with(|| Arc::new(Mutex::new(())))
72 .clone()
73 };
74 let _guard = key_lock.lock().await;
75
76 if let Some(entry) = {
77 let map = self.inner.read().expect("lock poisoned");
78 map.get(key).cloned()
79 } && Instant::now() < entry.expires_at
80 {
81 return Ok(entry.authorized);
82 }
83
84 let decision = validator_fn().await?;
85
86 {
87 let mut map = self.inner.write().expect("lock poisoned");
88 map.insert(
89 key.to_string(),
90 AuthEntry {
91 authorized: decision,
92 expires_at: Instant::now() + ttl,
93 },
94 );
95 }
96 debug!("Authorization cache updated for key.");
97 Ok(decision)
98 }
99
100 pub fn insert(&self, key: String, authorized: bool, ttl: Duration) {
102 let entry = AuthEntry {
103 authorized,
104 expires_at: Instant::now() + ttl,
105 };
106 let mut map = self.inner.write().expect("lock poisoned");
107 map.insert(key, entry);
108 }
109
110 pub fn invalidate(&self, key: &str) {
112 let mut map = self.inner.write().expect("lock poisoned");
113 map.remove(key);
114 }
115}
116
117pub async fn validate_request(
127 token: &str,
128 bucket: &str,
129 request: &HashMap<String, String>,
130 callback: PyObject,
131) -> Result<bool, String> {
132 let token = token.to_string();
133 let bucket = bucket.to_string();
134
135 let req = request
136 .iter()
137 .map(|(k, v)| (k.clone(), v.clone()))
138 .collect::<HashMap<String, String>>();
139
140 debug!("request details sent to Python callable: {:?}", &req);
141
142 let takes_request = Python::with_gil(|py| {
143 callable_accepts_request(py, &callback)
144 .map_err(|e| format!("Invalid callable signature: {:?}", e))
145 });
146
147 if takes_request.is_err() {
148 return Err(format!("Invalid callable signature: {:?}", takes_request));
149 }
150 let takes_request = takes_request.expect("checked above");
151
152 debug!("Python callable can take request: {:?}", &takes_request);
153
154 let authorized = if takes_request {
155 task::spawn_blocking(move || {
156 Python::with_gil(|py| {
157 match callback.call1(py, (token.as_str(), bucket.as_str(), &req)) {
158 Ok(result_obj) => result_obj
159 .extract::<bool>(py)
160 .map_err(|_| "Failed to extract boolean".to_string()),
161 Err(e) => {
162 error!("Python callback error: {:?}", e);
163 Err("Inner Python exception".to_string())
164 }
165 }
166 })
167 })
168 .await
169 .map_err(|e| format!("Join error: {:?}", e))??
170 } else {
171 task::spawn_blocking(move || {
172 Python::with_gil(
173 |py| match callback.call1(py, (token.as_str(), bucket.as_str())) {
174 Ok(result_obj) => result_obj
175 .extract::<bool>(py)
176 .map_err(|_| "Failed to extract boolean".to_string()),
177 Err(e) => {
178 error!("Python callback error: {:?}", e);
179 Err("Inner Python exception".to_string())
180 }
181 },
182 )
183 })
184 .await
185 .map_err(|e| format!("Join error: {:?}", e))??
186 };
187
188 Ok(authorized)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // Short TTL keeps the expiry test fast while leaving a wide margin for the
    // back-to-back cache-hit check.
    const TTL: Duration = Duration::from_millis(200);

    #[tokio::test]
    async fn auth_cache_get_or_validate_behaviors() {
        let cache = AuthCache::new();
        let key = "auth_key";
        let calls = Arc::new(AtomicUsize::new(0));

        // First lookup misses and runs the validator.
        let counter = Arc::clone(&calls);
        let res1 = cache
            .get_or_validate(key, TTL, move || {
                counter.fetch_add(1, Ordering::SeqCst);
                async { Ok::<bool, Infallible>(true) }
            })
            .await
            .unwrap();
        assert!(res1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Second lookup within the TTL is served from cache; the validator is not called.
        let counter = Arc::clone(&calls);
        let res2 = cache
            .get_or_validate(key, TTL, move || {
                counter.fetch_add(1, Ordering::SeqCst);
                async { Ok::<bool, Infallible>(false) }
            })
            .await
            .unwrap();
        assert!(res2);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // After expiry the validator runs again and its new decision wins.
        tokio::time::sleep(TTL * 2).await;
        let counter = Arc::clone(&calls);
        let res3 = cache
            .get_or_validate(key, TTL, move || {
                counter.fetch_add(1, Ordering::SeqCst);
                async { Ok::<bool, Infallible>(false) }
            })
            .await
            .unwrap();
        assert!(!res3);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn validator_errors_are_not_cached() {
        let cache = AuthCache::new();
        let failed = cache
            .get_or_validate("k", TTL, || async { Err::<bool, &'static str>("boom") })
            .await;
        assert_eq!(failed, Err("boom"));

        let retried = cache
            .get_or_validate("k", TTL, || async { Ok::<bool, &'static str>(true) })
            .await;
        assert_eq!(retried, Ok(true));
    }

    #[tokio::test]
    async fn insert_and_invalidate() {
        let cache = AuthCache::new();
        cache.insert("k".to_string(), false, TTL);
        let cached = cache
            .get_or_validate("k", TTL, || async { Ok::<bool, Infallible>(true) })
            .await
            .unwrap();
        assert!(!cached);

        cache.invalidate("k");
        let fresh = cache
            .get_or_validate("k", TTL, || async { Ok::<bool, Infallible>(true) })
            .await
            .unwrap();
        assert!(fresh);
    }

    #[tokio::test]
    async fn concurrent_misses_run_validator_once() {
        let cache = AuthCache::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let make_validator = || {
            let counter = Arc::clone(&calls);
            move || {
                let counter = Arc::clone(&counter);
                async move {
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok::<bool, Infallible>(true)
                }
            }
        };

        let (a, b, c) = tokio::join!(
            cache.get_or_validate("k", TTL, make_validator()),
            cache.get_or_validate("k", TTL, make_validator()),
            cache.get_or_validate("k", TTL, make_validator()),
        );
        assert_eq!((a, b, c), (Ok(true), Ok(true), Ok(true)));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}