Skip to main content

object_storage_proxy/utils/
validator.rs

1use dashmap::DashMap;
2use lru::LruCache;
3use pyo3::{PyObject, Python};
4use tokio::{sync::Mutex, task};
5use tracing::{debug, error};
6
7use std::{
8    collections::HashMap,
9    num::NonZeroUsize,
10    sync::{Arc, RwLock},
11    time::{Duration, Instant},
12};
13
/// One cached authorization decision together with the deadline after which
/// it must be re-validated.
#[derive(Debug, Clone)]
struct AuthEntry {
    /// Decision reported by the validator for this key.
    authorized: bool,
    /// The entry is only served while `Instant::now()` is before this instant.
    expires_at: Instant,
}
19
/// Capacity used by `AuthCache::default()`, and by `AuthCache::new` when it is
/// given `0`.
///
/// This is the maximum number of distinct cache keys (for example an
/// `(access_key, bucket, method)` combination) whose decisions are retained.
/// Inserting beyond it evicts an existing entry automatically, so no
/// background sweep task is required.
pub const AUTH_CACHE_DEFAULT_CAPACITY: usize = 1_024;
25
26/// A time-bounded, capacity-limited LRU cache for authorization decisions.
27///
28/// Wraps arbitrary async validator functions so that the (potentially
29/// expensive) Python callback is only invoked once per `(access_key, bucket,
30/// method)` tuple within the configured TTL window.
31///
32/// Memory is bounded by `capacity`: when the limit is reached the
33/// least-recently-used entry is evicted automatically (TODO perf-4: done).
34///
35/// Concurrent cache misses for the **same** key are serialised via a per-key
36/// [`Mutex`] to avoid thundering-herd stampedes.  Misses for **different** keys
37/// are fully concurrent — the per-key lock map is backed by a [`DashMap`] so
38/// no single global lock is held (TODO perf-3: done).
39#[derive(Clone, Debug)]
40pub struct AuthCache {
41    inner: Arc<RwLock<LruCache<String, AuthEntry>>>,
42    /// Per-key mutex map — DashMap so concurrent misses for different keys
43    /// never contend on a shared lock.
44    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
45}
46
47impl Default for AuthCache {
48    fn default() -> Self {
49        Self::new(AUTH_CACHE_DEFAULT_CAPACITY)
50    }
51}
52
53impl AuthCache {
54    pub fn new(capacity: usize) -> Self {
55        let cap = NonZeroUsize::new(capacity)
56            .unwrap_or(NonZeroUsize::new(AUTH_CACHE_DEFAULT_CAPACITY).unwrap());
57        AuthCache {
58            inner: Arc::new(RwLock::new(LruCache::new(cap))),
59            locks: Arc::new(DashMap::new()),
60        }
61    }
62
63    pub async fn get_or_validate<F, Fut, E>(
64        &self,
65        key: &str,
66        ttl: Duration,
67        validator_fn: F,
68    ) -> Result<bool, E>
69    where
70        F: Fn() -> Fut + Send + Sync + 'static,
71        Fut: std::future::Future<Output = Result<bool, E>> + Send,
72        E: std::fmt::Debug,
73    {
74        if let Some(entry) = {
75            // LruCache::peek does not promote the entry, preserving LRU order
76            // on a read-only hit — no write lock needed.
77            let map = self.inner.read().unwrap();
78            map.peek(key).cloned()
79        } && Instant::now() < entry.expires_at
80        {
81            debug!("Cache hit for key.");
82            return Ok(entry.authorized);
83        }
84        debug!("Cache miss for key. Validating authorization...");
85
86        // Obtain (or lazily create) the per-key mutex without holding a global
87        // lock across the async validation call.
88        let key_lock = self
89            .locks
90            .entry(key.to_string())
91            .or_insert_with(|| Arc::new(Mutex::new(())))
92            .clone();
93        let _guard = key_lock.lock().await;
94
95        // Double-checked locking: another task may have populated the entry
96        // while we were waiting for the per-key mutex.
97        if let Some(entry) = {
98            let map = self.inner.read().expect("lock poisoned");
99            map.peek(key).cloned()
100        } && Instant::now() < entry.expires_at
101        {
102            return Ok(entry.authorized);
103        }
104
105        let decision = validator_fn().await?;
106
107        {
108            let mut map = self.inner.write().expect("lock poisoned");
109            map.put(
110                key.to_string(),
111                AuthEntry {
112                    authorized: decision,
113                    expires_at: Instant::now() + ttl,
114                },
115            );
116        }
117        debug!("Authorization cache updated for key.");
118        Ok(decision)
119    }
120
121    /// Pre-populate the cache with a known decision for `key`.
122    pub fn insert(&self, key: String, authorized: bool, ttl: Duration) {
123        let entry = AuthEntry {
124            authorized,
125            expires_at: Instant::now() + ttl,
126        };
127        let mut map = self.inner.write().expect("lock poisoned");
128        map.put(key, entry);
129    }
130
131    /// Evict the cached entry for `key`, forcing re-validation on the next request.
132    pub fn invalidate(&self, key: &str) {
133        let mut map = self.inner.write().expect("lock poisoned");
134        map.pop(key);
135    }
136}
137
138/// Invoke the Python validator callback for a single request.
139///
140/// `takes_request` must be pre-computed once (e.g. at server startup via
141/// [`callable_accepts_request`]) and passed here to avoid re-running
142/// `inspect.signature` on every cache miss.
143pub async fn validate_request(
144    token: &str,
145    bucket: &str,
146    request: &HashMap<String, String>,
147    callback: PyObject,
148    takes_request: bool,
149) -> Result<bool, String> {
150    let token = token.to_string();
151    let bucket = bucket.to_string();
152
153    let req = request
154        .iter()
155        .map(|(k, v)| (k.clone(), v.clone()))
156        .collect::<HashMap<String, String>>();
157
158    debug!("request details sent to Python callable: {:?}", &req);
159
160    let authorized = if takes_request {
161        task::spawn_blocking(move || {
162            Python::with_gil(|py| {
163                match callback.call1(py, (token.as_str(), bucket.as_str(), &req)) {
164                    Ok(result_obj) => result_obj
165                        .extract::<bool>(py)
166                        .map_err(|_| "Failed to extract boolean".to_string()),
167                    Err(e) => {
168                        error!("Python callback error: {:?}", e);
169                        Err("Inner Python exception".to_string())
170                    }
171                }
172            })
173        })
174        .await
175        .map_err(|e| format!("Join error: {:?}", e))??
176    } else {
177        task::spawn_blocking(move || {
178            Python::with_gil(
179                |py| match callback.call1(py, (token.as_str(), bucket.as_str())) {
180                    Ok(result_obj) => result_obj
181                        .extract::<bool>(py)
182                        .map_err(|_| "Failed to extract boolean".to_string()),
183                    Err(e) => {
184                        error!("Python callback error: {:?}", e);
185                        Err("Inner Python exception".to_string())
186                    }
187                },
188            )
189        })
190        .await
191        .map_err(|e| format!("Join error: {:?}", e))??
192    };
193
194    Ok(authorized)
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Builds a validator that counts its invocations in `calls` and resolves
    /// to `decision`.
    fn counting_validator(
        calls: &Arc<AtomicUsize>,
        decision: bool,
    ) -> impl Fn() -> std::future::Ready<Result<bool, Infallible>> + Send + Sync + 'static {
        let calls = Arc::clone(calls);
        move || {
            calls.fetch_add(1, Ordering::SeqCst);
            std::future::ready(Ok(decision))
        }
    }

    #[tokio::test]
    async fn auth_cache_get_or_validate_behaviors() {
        let cache = AuthCache::new(AUTH_CACHE_DEFAULT_CAPACITY);
        let key = "auth_key";
        // A short TTL keeps the expiry check fast.
        let ttl = Duration::from_millis(200);
        let calls = Arc::new(AtomicUsize::new(0));

        // First call: miss, so the validator runs.
        let res1 = cache
            .get_or_validate(key, ttl, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(res1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Within TTL: hit, so the cached `true` wins and the validator is not called.
        let res2 = cache
            .get_or_validate(key, ttl, counting_validator(&calls, false))
            .await
            .unwrap();
        assert!(res2);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // After expiry: miss again, and the new decision is returned.
        tokio::time::sleep(ttl * 2).await;
        let res3 = cache
            .get_or_validate(key, ttl, counting_validator(&calls, false))
            .await
            .unwrap();
        assert!(!res3);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn auth_cache_does_not_cache_errors() {
        let cache = AuthCache::default();
        let ttl = Duration::from_secs(60);

        let failed = cache
            .get_or_validate("err_key", ttl, || async {
                Err::<bool, String>("boom".to_string())
            })
            .await;
        assert_eq!(failed, Err("boom".to_string()));

        // The failure was not stored, so the next call validates again.
        let calls = Arc::new(AtomicUsize::new(0));
        let res = cache
            .get_or_validate("err_key", ttl, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(res);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn auth_cache_insert_and_invalidate() {
        let cache = AuthCache::new(AUTH_CACHE_DEFAULT_CAPACITY);
        let ttl = Duration::from_secs(60);
        let calls = Arc::new(AtomicUsize::new(0));

        cache.insert("k".to_string(), false, ttl);
        let res = cache
            .get_or_validate("k", ttl, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(!res);
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        cache.invalidate("k");
        let res = cache
            .get_or_validate("k", ttl, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(res);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn auth_cache_evicts_when_full() {
        let cache = AuthCache::new(1);
        let ttl = Duration::from_secs(60);
        let calls = Arc::new(AtomicUsize::new(0));

        cache.insert("a".to_string(), true, ttl);
        cache.insert("b".to_string(), true, ttl);

        // Capacity 1: inserting "b" evicted "a", so "a" must be re-validated.
        let res = cache
            .get_or_validate("a", ttl, counting_validator(&calls, false))
            .await
            .unwrap();
        assert!(!res);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn auth_cache_serialises_concurrent_misses() {
        let cache = AuthCache::default();
        let ttl = Duration::from_secs(60);
        let calls = Arc::new(AtomicUsize::new(0));

        // The validator yields, so the other futures reach the per-key mutex
        // while the first miss is still in flight.
        let slow_validator = || {
            let calls = Arc::clone(&calls);
            move || {
                let calls = Arc::clone(&calls);
                async move {
                    tokio::time::sleep(Duration::from_millis(20)).await;
                    calls.fetch_add(1, Ordering::SeqCst);
                    Ok::<bool, Infallible>(true)
                }
            }
        };

        let (a, b, c) = tokio::join!(
            cache.get_or_validate("shared", ttl, slow_validator()),
            cache.get_or_validate("shared", ttl, slow_validator()),
            cache.get_or_validate("shared", ttl, slow_validator()),
        );
        assert_eq!((a.unwrap(), b.unwrap(), c.unwrap()), (true, true, true));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}
253}