Skip to main content

object_storage_proxy/utils/
validator.rs

1use pyo3::{PyObject, Python};
2use tokio::{sync::Mutex, task};
3use tracing::{debug, error};
4
5use std::{
6    collections::HashMap,
7    sync::{Arc, RwLock},
8    time::{Duration, Instant},
9};
10
11use crate::utils::functions::callable_accepts_request;
12
/// A single cached authorization decision and the deadline after which it is stale.
#[derive(Debug, Clone)]
struct AuthEntry {
    /// The cached allow (`true`) / deny (`false`) decision.
    authorized: bool,
    /// Lookups made at or after this instant treat the entry as expired.
    expires_at: Instant,
}
18
19/// A time-bounded, per-request-key cache for authorization decisions.
20///
21/// Wraps arbitrary async validator functions so that the (potentially
22/// expensive) Python callback is only invoked once per `(access_key, bucket,
23/// query)` tuple within the configured TTL window.
24///
25/// Concurrent cache misses for the same key are serialized via a per-key
26/// [`Mutex`] to avoid thundering-herd stampedes.
27#[derive(Clone, Debug)]
28pub struct AuthCache {
29    inner: Arc<RwLock<HashMap<String, AuthEntry>>>,
30    locks: Arc<Mutex<HashMap<String, Arc<Mutex<()>>>>>,
31}
32
33impl Default for AuthCache {
34    fn default() -> Self {
35        Self::new()
36    }
37}
38
39impl AuthCache {
40    pub fn new() -> Self {
41        AuthCache {
42            inner: Arc::new(RwLock::new(HashMap::new())),
43            locks: Arc::new(Mutex::new(HashMap::new())),
44        }
45    }
46
47    pub async fn get_or_validate<F, Fut, E>(
48        &self,
49        key: &str,
50        ttl: Duration,
51        validator_fn: F,
52    ) -> Result<bool, E>
53    where
54        F: Fn() -> Fut + Send + Sync + 'static,
55        Fut: std::future::Future<Output = Result<bool, E>> + Send,
56        E: std::fmt::Debug,
57    {
58        if let Some(entry) = {
59            let map = self.inner.read().unwrap();
60            map.get(key).cloned()
61        } && Instant::now() < entry.expires_at
62        {
63            debug!("Cache hit for key.");
64            return Ok(entry.authorized);
65        }
66        debug!("Cache miss for key. Validating authorization...");
67        let key_lock = {
68            let mut locks_map = self.locks.lock().await;
69            locks_map
70                .entry(key.to_string())
71                .or_insert_with(|| Arc::new(Mutex::new(())))
72                .clone()
73        };
74        let _guard = key_lock.lock().await;
75
76        if let Some(entry) = {
77            let map = self.inner.read().expect("lock poisoned");
78            map.get(key).cloned()
79        } && Instant::now() < entry.expires_at
80        {
81            return Ok(entry.authorized);
82        }
83
84        let decision = validator_fn().await?;
85
86        {
87            let mut map = self.inner.write().expect("lock poisoned");
88            map.insert(
89                key.to_string(),
90                AuthEntry {
91                    authorized: decision,
92                    expires_at: Instant::now() + ttl,
93                },
94            );
95        }
96        debug!("Authorization cache updated for key.");
97        Ok(decision)
98    }
99
100    /// Pre-populate the cache with a known decision for `key`.
101    pub fn insert(&self, key: String, authorized: bool, ttl: Duration) {
102        let entry = AuthEntry {
103            authorized,
104            expires_at: Instant::now() + ttl,
105        };
106        let mut map = self.inner.write().expect("lock poisoned");
107        map.insert(key, entry);
108    }
109
110    /// Evict the cached entry for `key`, forcing re-validation on the next request.
111    pub fn invalidate(&self, key: &str) {
112        let mut map = self.inner.write().expect("lock poisoned");
113        map.remove(key);
114    }
115}
116
117/// Invoke the Python validator callback for a single request.
118///
119/// Inspects the callable's signature at runtime (via Python's `inspect` module)
120/// to determine whether it accepts a third `request: dict` argument.  The call
121/// is dispatched on a [`tokio::task::spawn_blocking`] thread so that the Python
122/// GIL does not block the async runtime.
123///
124/// Returns `Ok(true)` if the callback approves the request, `Ok(false)` if it
125/// denies it, or `Err(String)` on any Python exception.
126pub async fn validate_request(
127    token: &str,
128    bucket: &str,
129    request: &HashMap<String, String>,
130    callback: PyObject,
131) -> Result<bool, String> {
132    let token = token.to_string();
133    let bucket = bucket.to_string();
134
135    let req = request
136        .iter()
137        .map(|(k, v)| (k.clone(), v.clone()))
138        .collect::<HashMap<String, String>>();
139
140    debug!("request details sent to Python callable: {:?}", &req);
141
142    let takes_request = Python::with_gil(|py| {
143        callable_accepts_request(py, &callback)
144            .map_err(|e| format!("Invalid callable signature: {:?}", e))
145    });
146
147    if takes_request.is_err() {
148        return Err(format!("Invalid callable signature: {:?}", takes_request));
149    }
150    let takes_request = takes_request.expect("checked above");
151
152    debug!("Python callable can take request: {:?}", &takes_request);
153
154    let authorized = if takes_request {
155        task::spawn_blocking(move || {
156            Python::with_gil(|py| {
157                match callback.call1(py, (token.as_str(), bucket.as_str(), &req)) {
158                    Ok(result_obj) => result_obj
159                        .extract::<bool>(py)
160                        .map_err(|_| "Failed to extract boolean".to_string()),
161                    Err(e) => {
162                        error!("Python callback error: {:?}", e);
163                        Err("Inner Python exception".to_string())
164                    }
165                }
166            })
167        })
168        .await
169        .map_err(|e| format!("Join error: {:?}", e))??
170    } else {
171        task::spawn_blocking(move || {
172            Python::with_gil(
173                |py| match callback.call1(py, (token.as_str(), bucket.as_str())) {
174                    Ok(result_obj) => result_obj
175                        .extract::<bool>(py)
176                        .map_err(|_| "Failed to extract boolean".to_string()),
177                    Err(e) => {
178                        error!("Python callback error: {:?}", e);
179                        Err("Inner Python exception".to_string())
180                    }
181                },
182            )
183        })
184        .await
185        .map_err(|e| format!("Join error: {:?}", e))??
186    };
187
188    Ok(authorized)
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use std::convert::Infallible;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Long enough that back-to-back calls land inside the window, short enough
    /// that waiting out an expiry keeps the test suite fast.
    const TTL: Duration = Duration::from_millis(250);

    /// Build a validator that counts its invocations in `calls` and always
    /// returns `decision`.
    fn counting_validator(
        calls: &Arc<AtomicUsize>,
        decision: bool,
    ) -> impl Fn() -> std::future::Ready<Result<bool, Infallible>> + Send + Sync + 'static {
        let calls = Arc::clone(calls);
        move || {
            calls.fetch_add(1, Ordering::SeqCst);
            std::future::ready(Ok(decision))
        }
    }

    #[tokio::test]
    async fn auth_cache_get_or_validate_behaviors() {
        let cache = AuthCache::new();
        let key = "auth_key";
        let calls = Arc::new(AtomicUsize::new(0));

        // First call is a miss: the validator runs and its decision is cached.
        let res1 = cache
            .get_or_validate(key, TTL, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(res1);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // Second call within TTL: cache hit, no new validator call.
        let res2 = cache
            .get_or_validate(key, TTL, counting_validator(&calls, false))
            .await
            .unwrap();
        assert!(res2);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        // After expiry the validator runs again and its new decision wins.
        tokio::time::sleep(TTL * 2).await;
        let res3 = cache
            .get_or_validate(key, TTL, counting_validator(&calls, false))
            .await
            .unwrap();
        assert!(!res3);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn auth_cache_insert_and_invalidate() {
        let cache = AuthCache::new();
        let key = "preloaded";
        let calls = Arc::new(AtomicUsize::new(0));

        // A pre-populated decision is served without calling the validator.
        cache.insert(key.to_string(), false, TTL);
        let res = cache
            .get_or_validate(key, TTL, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(!res);
        assert_eq!(calls.load(Ordering::SeqCst), 0);

        // Invalidation forces re-validation on the next request.
        cache.invalidate(key);
        let res = cache
            .get_or_validate(key, TTL, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(res);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn auth_cache_does_not_cache_errors() {
        let cache = AuthCache::new();
        let key = "flaky";

        let err = cache
            .get_or_validate(key, TTL, || std::future::ready(Err::<bool, &str>("boom")))
            .await;
        assert_eq!(err, Err("boom"));

        // The failed validation left nothing behind, so the next call validates again.
        let calls = Arc::new(AtomicUsize::new(0));
        let res = cache
            .get_or_validate(key, TTL, counting_validator(&calls, true))
            .await
            .unwrap();
        assert!(res);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}