// object_storage_proxy/utils/validator.rs
use std::{
    collections::HashMap,
    num::NonZeroUsize,
    sync::{Arc, PoisonError, RwLock},
    time::{Duration, Instant},
};

use dashmap::DashMap;
use lru::LruCache;
use pyo3::{PyObject, Python};
use tokio::{sync::Mutex, task};
use tracing::{debug, error};
/// One cached authorization decision together with its expiry time.
#[derive(Debug, Clone)]
struct AuthEntry {
    /// The decision produced by the validator for this key.
    authorized: bool,
    /// The entry counts as fresh only while `Instant::now()` is before this.
    expires_at: Instant,
}
/// Capacity used by `AuthCache::default()`, and the fallback when
/// `AuthCache::new` is given a capacity of zero.
pub const AUTH_CACHE_DEFAULT_CAPACITY: usize = 1 << 10;
26#[derive(Clone, Debug)]
40pub struct AuthCache {
41 inner: Arc<RwLock<LruCache<String, AuthEntry>>>,
42 locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
45}
46
47impl Default for AuthCache {
48 fn default() -> Self {
49 Self::new(AUTH_CACHE_DEFAULT_CAPACITY)
50 }
51}
52
53impl AuthCache {
54 pub fn new(capacity: usize) -> Self {
55 let cap = NonZeroUsize::new(capacity)
56 .unwrap_or(NonZeroUsize::new(AUTH_CACHE_DEFAULT_CAPACITY).unwrap());
57 AuthCache {
58 inner: Arc::new(RwLock::new(LruCache::new(cap))),
59 locks: Arc::new(DashMap::new()),
60 }
61 }
62
63 pub async fn get_or_validate<F, Fut, E>(
64 &self,
65 key: &str,
66 ttl: Duration,
67 validator_fn: F,
68 ) -> Result<bool, E>
69 where
70 F: Fn() -> Fut + Send + Sync + 'static,
71 Fut: std::future::Future<Output = Result<bool, E>> + Send,
72 E: std::fmt::Debug,
73 {
74 if let Some(entry) = {
75 let map = self.inner.read().unwrap();
78 map.peek(key).cloned()
79 } && Instant::now() < entry.expires_at
80 {
81 debug!("Cache hit for key.");
82 return Ok(entry.authorized);
83 }
84 debug!("Cache miss for key. Validating authorization...");
85
86 let key_lock = self
89 .locks
90 .entry(key.to_string())
91 .or_insert_with(|| Arc::new(Mutex::new(())))
92 .clone();
93 let _guard = key_lock.lock().await;
94
95 if let Some(entry) = {
98 let map = self.inner.read().expect("lock poisoned");
99 map.peek(key).cloned()
100 } && Instant::now() < entry.expires_at
101 {
102 return Ok(entry.authorized);
103 }
104
105 let decision = validator_fn().await?;
106
107 {
108 let mut map = self.inner.write().expect("lock poisoned");
109 map.put(
110 key.to_string(),
111 AuthEntry {
112 authorized: decision,
113 expires_at: Instant::now() + ttl,
114 },
115 );
116 }
117 debug!("Authorization cache updated for key.");
118 Ok(decision)
119 }
120
121 pub fn insert(&self, key: String, authorized: bool, ttl: Duration) {
123 let entry = AuthEntry {
124 authorized,
125 expires_at: Instant::now() + ttl,
126 };
127 let mut map = self.inner.write().expect("lock poisoned");
128 map.put(key, entry);
129 }
130
131 pub fn invalidate(&self, key: &str) {
133 let mut map = self.inner.write().expect("lock poisoned");
134 map.pop(key);
135 }
136}
137
138pub async fn validate_request(
144 token: &str,
145 bucket: &str,
146 request: &HashMap<String, String>,
147 callback: PyObject,
148 takes_request: bool,
149) -> Result<bool, String> {
150 let token = token.to_string();
151 let bucket = bucket.to_string();
152
153 let req = request
154 .iter()
155 .map(|(k, v)| (k.clone(), v.clone()))
156 .collect::<HashMap<String, String>>();
157
158 debug!("request details sent to Python callable: {:?}", &req);
159
160 let authorized = if takes_request {
161 task::spawn_blocking(move || {
162 Python::with_gil(|py| {
163 match callback.call1(py, (token.as_str(), bucket.as_str(), &req)) {
164 Ok(result_obj) => result_obj
165 .extract::<bool>(py)
166 .map_err(|_| "Failed to extract boolean".to_string()),
167 Err(e) => {
168 error!("Python callback error: {:?}", e);
169 Err("Inner Python exception".to_string())
170 }
171 }
172 })
173 })
174 .await
175 .map_err(|e| format!("Join error: {:?}", e))??
176 } else {
177 task::spawn_blocking(move || {
178 Python::with_gil(
179 |py| match callback.call1(py, (token.as_str(), bucket.as_str())) {
180 Ok(result_obj) => result_obj
181 .extract::<bool>(py)
182 .map_err(|_| "Failed to extract boolean".to_string()),
183 Err(e) => {
184 error!("Python callback error: {:?}", e);
185 Err("Inner Python exception".to_string())
186 }
187 },
188 )
189 })
190 .await
191 .map_err(|e| format!("Join error: {:?}", e))??
192 };
193
194 Ok(authorized)
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a validator that records each call in `calls` and resolves to `decision`.
    fn counting_validator(
        calls: &Arc<AtomicUsize>,
        decision: bool,
    ) -> impl Fn() -> std::future::Ready<Result<bool, Infallible>> + Send + Sync + 'static {
        let calls = Arc::clone(calls);
        move || {
            calls.fetch_add(1, Ordering::SeqCst);
            std::future::ready(Ok(decision))
        }
    }

    #[tokio::test]
    async fn auth_cache_get_or_validate_behaviors() {
        // Short enough that the expiry step does not stall the suite, long
        // enough that the first two lookups land well inside it.
        const TTL: Duration = Duration::from_millis(300);
        let cache = AuthCache::new(AUTH_CACHE_DEFAULT_CAPACITY);
        let key = "auth_key";
        let calls = Arc::new(AtomicUsize::new(0));

        // Miss: the validator runs and its decision is cached.
        let res1 = cache
            .get_or_validate(key, TTL, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(res1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Hit: the cached `true` is returned and the validator is not called.
        let res2 = cache
            .get_or_validate(key, TTL, counting_validator(&calls, false))
            .await
            .unwrap();
        assert!(res2);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Expired: the validator runs again and its fresh decision wins.
        tokio::time::sleep(TTL * 2).await;
        let res3 = cache
            .get_or_validate(key, TTL, counting_validator(&calls, false))
            .await
            .unwrap();
        assert!(!res3);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn auth_cache_concurrent_misses_validate_once() {
        let cache = AuthCache::new(AUTH_CACHE_DEFAULT_CAPACITY);
        let calls = Arc::new(AtomicUsize::new(0));

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let cache = cache.clone();
                let calls = Arc::clone(&calls);
                tokio::spawn(async move {
                    cache
                        .get_or_validate("shared_key", Duration::from_secs(60), move || {
                            let calls = Arc::clone(&calls);
                            async move {
                                calls.fetch_add(1, Ordering::SeqCst);
                                // Stay pending so the other tasks miss the cache
                                // and queue on the per-key lock.
                                tokio::time::sleep(Duration::from_millis(20)).await;
                                Ok::<bool, Infallible>(true)
                            }
                        })
                        .await
                        .unwrap()
                })
            })
            .collect();

        for handle in handles {
            assert!(handle.await.unwrap());
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}