// object_storage_proxy/credentials/secrets_proxy.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use pyo3::{PyObject, PyResult, Python};
7use reqwest::Client;
8use serde::Deserialize;
9use tracing::{debug, error};
10
/// A cached IAM bearer token together with its UNIX expiry timestamp.
#[derive(Clone, Debug)]
pub struct SecretValue {
    value: String,
    expiration: u64,
}

impl SecretValue {
    /// Safety margin, in seconds, applied by [`is_expired`](Self::is_expired) so that a
    /// token is treated as expired (and renewed) shortly before it actually lapses.
    pub const EXPIRY_BUFFER_SECS: u64 = 300;

    /// Create a new secret with `value` that expires at the given UNIX timestamp.
    pub fn new(value: String, expiration: u64) -> Self {
        SecretValue { value, expiration }
    }

    /// Return the raw token string.
    pub fn get_value(&self) -> &str {
        &self.value
    }

    /// Return the UNIX expiry timestamp.
    pub fn get_expiration(&self) -> u64 {
        self.expiration
    }

    /// Returns `true` if the token has expired or will expire within
    /// [`EXPIRY_BUFFER_SECS`](Self::EXPIRY_BUFFER_SECS) seconds (5 minutes).
    ///
    /// An `expiration` smaller than the buffer (e.g. `0`) is reported as expired.
    /// Previously the subtraction underflowed: it panicked in debug builds and, in
    /// release builds, wrapped to a huge value, so such a token was never expired.
    ///
    /// # Panics
    ///
    /// Panics if the system clock reports a time before the Unix epoch.
    pub fn is_expired(&self) -> bool {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .expect("system clock before Unix epoch")
            .as_secs();
        now >= self.expiration.saturating_sub(Self::EXPIRY_BUFFER_SECS)
    }
}
43
44/// A thread-safe, in-memory cache for IBM IAM bearer tokens.
45///
46/// Tokens are automatically refreshed when they are within 5 minutes of
47/// expiry.  The cache is keyed by bucket name and shared across all Pingora
48/// worker threads via an [`Arc`].
49#[derive(Clone, Debug)]
50pub struct SecretsCache {
51    inner: Arc<RwLock<HashMap<String, SecretValue>>>,
52}
53
54impl Default for SecretsCache {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl SecretsCache {
61    /// Create a new, empty secrets cache.
62    pub fn new() -> Self {
63        SecretsCache {
64            inner: Arc::new(RwLock::new(HashMap::new())),
65        }
66    }
67
68    /// Insert a token `value` for `key` that expires at the given UNIX timestamp.
69    pub fn insert(&self, key: String, value: String, expiration: u64) {
70        let secret = SecretValue { value, expiration };
71
72        let mut map = self.inner.write().expect("lock poisoned");
73        map.insert(key, secret);
74    }
75
76    /// Return a valid token for `key`, fetching a new one via `bearer_fetcher` if
77    /// the cache is empty or the stored token is near expiry.
78    pub async fn get<F, Fut>(&self, key: &str, bearer_fetcher: F) -> Option<String>
79    where
80        F: Fn() -> Fut + Send + 'static,
81        Fut: std::future::Future<Output = Result<IamResponse, Box<dyn std::error::Error>>> + Send,
82    {
83        let maybe_secret = {
84            let map = self.inner.read().expect("lock poisoned");
85            map.get(key).cloned()
86        };
87
88        match maybe_secret {
89            Some(secret) => {
90                if secret.is_expired() {
91                    debug!("Token for {} is expired, renewing ...", key);
92                    match bearer_fetcher().await {
93                        Ok(iam_response) => {
94                            self.insert(
95                                key.to_string(),
96                                iam_response.access_token.clone(),
97                                iam_response.expiration,
98                            );
99                            debug!("Renewed token for {}", key);
100                            Some(iam_response.access_token)
101                        }
102                        Err(e) => {
103                            error!("Failed to renew token for {}: {}", key, e);
104                            None
105                        }
106                    }
107                } else {
108                    debug!("Using cached token for {}", key);
109                    Some(secret.get_value().to_string())
110                }
111            }
112            None => {
113                debug!("No cached token found for {}, fetching ...", key);
114                match bearer_fetcher().await {
115                    Ok(iam_response) => {
116                        self.insert(
117                            key.to_string(),
118                            iam_response.access_token.clone(),
119                            iam_response.expiration,
120                        );
121                        debug!("Fetched new token for {}", key);
122                        Some(iam_response.access_token)
123                    }
124                    Err(e) => {
125                        error!("Failed to fetch token for {}: {}", key, e);
126                        None
127                    }
128                }
129            }
130        }
131    }
132
133    /// Remove the cached token for `key`, forcing a fresh fetch on the next call to [`get`](Self::get).
134    pub fn invalidate(&self, key: &str) {
135        let mut map = self.inner.write().expect("lock poisoned");
136        map.remove(key);
137    }
138}
139
140/// Deserialized response from the IBM IAM `/identity/token` endpoint.
141#[derive(Deserialize, Debug)]
142pub struct IamResponse {
143    /// The short-lived OAuth2 access token.
144    pub access_token: String,
145    /// Token lifetime in seconds (typically 3600).
146    pub expires_in: u32,
147    /// Absolute UNIX expiry timestamp.
148    pub expiration: u64,
149}
150
151/// Exchange an IBM COS API key for an IAM bearer token.
152///
153/// Posts to `https://iam.cloud.ibm.com/identity/token` and returns the
154/// deserialized [`IamResponse`] on success.
155pub(crate) async fn get_bearer(api_key: String) -> Result<IamResponse, Box<dyn std::error::Error>> {
156    debug!("Fetching bearer token for the API key");
157    let client = Client::new();
158
159    let params = [
160        ("grant_type", "urn:ibm:params:oauth:grant-type:apikey"),
161        ("apikey", &api_key),
162    ];
163
164    // todo: move url to config
165    let resp = client
166        .post("https://iam.cloud.ibm.com/identity/token")
167        .header("Content-Type", "application/x-www-form-urlencoded")
168        .form(&params)
169        .send()
170        .await?;
171
172    if resp.status().is_success() {
173        let iam_response: IamResponse = resp.json().await?;
174        debug!("Received access token");
175        Ok(iam_response)
176    } else {
177        let err_text = resp.text().await?;
178        error!("Failed to get token: {}", err_text);
179        Err(format!("Failed to get token: {}", err_text).into())
180    }
181}
182
183/// Call the Python `bucket_creds_fetcher` callback and return the raw
184/// credential string it produces.
185///
186/// The callback receives `(token, bucket)` as positional arguments and must
187/// return a `str`.
188pub(crate) async fn get_credential_for_bucket(
189    callback: &PyObject,
190    bucket: String,
191    token: String,
192) -> PyResult<String> {
193    Python::with_gil(|py| {
194        let s = callback.call1(py, (token, bucket))?;
195        s.extract::<String>(py)
196    })
197}
198
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Current time as whole seconds since the Unix epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Build a mock response with a raw JSON body and the given HTTP status.
    fn make_response(json_body: &str, status_code: u16) -> ResponseTemplate {
        ResponseTemplate::new(status_code).set_body_raw(json_body, "application/json")
    }

    #[tokio::test]
    async fn test_get_bearer_success() {
        let mock_server = MockServer::start().await;

        let response_body = r#"{
            "access_token": "mock_access_token",
            "expires_in": 3600,
            "expiration": 9999999999
        }"#;

        Mock::given(method("POST"))
            .and(path("/identity/token"))
            .respond_with(make_response(response_body, 200))
            .mount(&mock_server)
            .await;

        let result = get_bearer_with_url("mock_api_key".to_string(), &mock_server.uri()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "mock_access_token");
    }

    #[tokio::test]
    async fn test_get_bearer_failure() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/identity/token"))
            .respond_with(ResponseTemplate::new(400).set_body_string("Invalid API key"))
            .mount(&mock_server)
            .await;

        let result = get_bearer_with_url("invalid_api_key".to_string(), &mock_server.uri()).await;
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "Failed to get token: Invalid API key"
        );
    }

    #[tokio::test]
    async fn test_get_bearer_invalid_json() {
        let mock_server = MockServer::start().await;

        let invalid_json = r#"{
            "invalid_field": "value"
        }"#;

        Mock::given(method("POST"))
            .and(path("/identity/token"))
            .respond_with(make_response(invalid_json, 200))
            .mount(&mock_server)
            .await;

        let result = get_bearer_with_url("mock_api_key".to_string(), &mock_server.uri()).await;
        assert!(result.is_err());

        let err_message = result.unwrap_err().to_string();
        assert!(
            err_message.contains("missing field `access_token`")
                || err_message.contains("error decoding response body"),
            "Unexpected error message: {}",
            err_message
        );
    }

    /// Test-only copy of the `get_bearer` request logic with a configurable base URL,
    /// returning just the access token.
    ///
    /// NOTE(review): this duplicates `get_bearer` instead of calling it, so the
    /// `test_get_bearer_*` tests above exercise this copy rather than the production
    /// function. Keep the two in sync until both go through one URL-parameterised helper.
    async fn get_bearer_with_url(
        api_key: String,
        base_url: &str,
    ) -> Result<String, Box<dyn std::error::Error>> {
        let client = reqwest::Client::new();
        let params = [
            ("grant_type", "urn:ibm:params:oauth:grant-type:apikey"),
            ("apikey", &api_key),
        ];
        let resp = client
            .post(&format!("{}/identity/token", base_url))
            .header("Content-Type", "application/x-www-form-urlencoded")
            .form(&params)
            .send()
            .await?;

        if resp.status().is_success() {
            let iam_response: IamResponse = resp.json().await?;
            Ok(iam_response.access_token)
        } else {
            let err_text = resp.text().await?;
            Err(format!("Failed to get token: {}", err_text).into())
        }
    }

    #[tokio::test]
    async fn secrets_cache_hit_returns_cached_value() {
        let cache = SecretsCache::new();
        let key = "test".to_string();

        cache.insert(key.clone(), "cached_token".to_string(), unix_now() + 3600);

        let fetcher = || async { panic!("Should not be called on cache hit") };

        let result = cache.get(&key, fetcher).await;
        assert_eq!(result, Some("cached_token".to_string()));
    }

    #[tokio::test]
    async fn secrets_cache_expired_renews_token() {
        let cache = SecretsCache::new();
        let key = "test2".to_string();
        // expired token
        cache.insert(key.clone(), "old_token".to_string(), unix_now());

        // fetcher returns new token
        let fetcher = || async {
            Ok(IamResponse {
                access_token: "new_token".into(),
                expires_in: 3600,
                expiration: unix_now() + 7200,
            })
        };

        let result = cache.get(&key, fetcher).await;
        assert_eq!(result, Some("new_token".to_string()));
    }

    #[tokio::test]
    async fn secrets_cache_invalidate_works() {
        let cache = SecretsCache::new();
        let key = "test3".to_string();
        cache.insert(key.clone(), "token".to_string(), unix_now() + 3600);

        cache.invalidate(&key);

        // now fetcher must be called
        let fetcher = || async {
            Ok(IamResponse {
                access_token: "fresh_token".into(),
                expires_in: 3600,
                expiration: unix_now() + 3600,
            })
        };
        let result = cache.get(&key, fetcher).await;
        assert_eq!(result, Some("fresh_token".to_string()));
    }

    #[tokio::test]
    async fn secrets_cache_miss_fetches_and_stores_token() {
        let cache = SecretsCache::new();
        let key = "test4".to_string();

        let fetcher = || async {
            Ok(IamResponse {
                access_token: "first_token".into(),
                expires_in: 3600,
                expiration: unix_now() + 3600,
            })
        };
        assert_eq!(
            cache.get(&key, fetcher).await,
            Some("first_token".to_string())
        );

        // The fetched token must now be served from the cache without refetching.
        let fetcher = || async { panic!("Should not be called on cache hit") };
        assert_eq!(
            cache.get(&key, fetcher).await,
            Some("first_token".to_string())
        );
    }

    #[tokio::test]
    async fn secrets_cache_fetch_error_returns_none() {
        let cache = SecretsCache::new();

        let fetcher = || async {
            Err::<IamResponse, Box<dyn std::error::Error>>("iam unavailable".into())
        };

        assert_eq!(cache.get("test5", fetcher).await, None);
    }
}