// object_storage_proxy/credentials/secrets_proxy.rs
use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
};

use pyo3::{PyObject, PyResult, Python};
use reqwest::Client;
use serde::Deserialize;
use tracing::{debug, error};

/// A cached secret together with its absolute expiry time in Unix seconds.
#[derive(Clone, Debug)]
pub struct SecretValue {
    value: String,
    expiration: u64,
}

impl SecretValue {
    /// How many seconds before `expiration` the secret is already reported as expired, so that
    /// the cache refreshes a token shortly before it actually lapses.
    const EXPIRY_MARGIN_SECS: u64 = 300;

    /// Creates a secret holding `value` that expires at `expiration` (Unix seconds).
    pub fn new(value: String, expiration: u64) -> Self {
        SecretValue { value, expiration }
    }

    /// Returns the secret value.
    pub fn get_value(&self) -> &str {
        &self.value
    }

    /// Returns the expiry time in Unix seconds.
    pub fn get_expiration(&self) -> u64 {
        self.expiration
    }

    /// Returns `true` once the current time is within `EXPIRY_MARGIN_SECS` of `expiration`
    /// (or past it).
    ///
    /// # Panics
    /// Panics if the system clock reports a time before the Unix epoch.
    pub fn is_expired(&self) -> bool {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::SystemTime::UNIX_EPOCH)
            .expect("system clock before Unix epoch")
            .as_secs();
        // `saturating_sub` keeps small expirations from underflowing: a plain subtraction
        // panics in debug builds and wraps in release builds, where the wrapped value would make
        // the secret look permanently valid.
        now >= self.expiration.saturating_sub(Self::EXPIRY_MARGIN_SECS)
    }
}

44#[derive(Clone, Debug)]
50pub struct SecretsCache {
51 inner: Arc<RwLock<HashMap<String, SecretValue>>>,
52}
53
54impl Default for SecretsCache {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl SecretsCache {
61 pub fn new() -> Self {
63 SecretsCache {
64 inner: Arc::new(RwLock::new(HashMap::new())),
65 }
66 }
67
68 pub fn insert(&self, key: String, value: String, expiration: u64) {
70 let secret = SecretValue { value, expiration };
71
72 let mut map = self.inner.write().expect("lock poisoned");
73 map.insert(key, secret);
74 }
75
76 pub async fn get<F, Fut>(&self, key: &str, bearer_fetcher: F) -> Option<String>
79 where
80 F: Fn() -> Fut + Send + 'static,
81 Fut: std::future::Future<Output = Result<IamResponse, Box<dyn std::error::Error>>> + Send,
82 {
83 let maybe_secret = {
84 let map = self.inner.read().expect("lock poisoned");
85 map.get(key).cloned()
86 };
87
88 match maybe_secret {
89 Some(secret) => {
90 if secret.is_expired() {
91 debug!("Token for {} is expired, renewing ...", key);
92 match bearer_fetcher().await {
93 Ok(iam_response) => {
94 self.insert(
95 key.to_string(),
96 iam_response.access_token.clone(),
97 iam_response.expiration,
98 );
99 debug!("Renewed token for {}", key);
100 Some(iam_response.access_token)
101 }
102 Err(e) => {
103 error!("Failed to renew token for {}: {}", key, e);
104 None
105 }
106 }
107 } else {
108 debug!("Using cached token for {}", key);
109 Some(secret.get_value().to_string())
110 }
111 }
112 None => {
113 debug!("No cached token found for {}, fetching ...", key);
114 match bearer_fetcher().await {
115 Ok(iam_response) => {
116 self.insert(
117 key.to_string(),
118 iam_response.access_token.clone(),
119 iam_response.expiration,
120 );
121 debug!("Fetched new token for {}", key);
122 Some(iam_response.access_token)
123 }
124 Err(e) => {
125 error!("Failed to fetch token for {}: {}", key, e);
126 None
127 }
128 }
129 }
130 }
131 }
132
133 pub fn invalidate(&self, key: &str) {
135 let mut map = self.inner.write().expect("lock poisoned");
136 map.remove(key);
137 }
138}
139
140#[derive(Deserialize, Debug)]
142pub struct IamResponse {
143 pub access_token: String,
145 pub expires_in: u32,
147 pub expiration: u64,
149}
150
151pub(crate) async fn get_bearer(api_key: String) -> Result<IamResponse, Box<dyn std::error::Error>> {
156 debug!("Fetching bearer token for the API key");
157 let client = Client::new();
158
159 let params = [
160 ("grant_type", "urn:ibm:params:oauth:grant-type:apikey"),
161 ("apikey", &api_key),
162 ];
163
164 let resp = client
166 .post("https://iam.cloud.ibm.com/identity/token")
167 .header("Content-Type", "application/x-www-form-urlencoded")
168 .form(¶ms)
169 .send()
170 .await?;
171
172 if resp.status().is_success() {
173 let iam_response: IamResponse = resp.json().await?;
174 debug!("Received access token");
175 Ok(iam_response)
176 } else {
177 let err_text = resp.text().await?;
178 error!("Failed to get token: {}", err_text);
179 Err(format!("Failed to get token: {}", err_text).into())
180 }
181}
182
183pub(crate) async fn get_credential_for_bucket(
189 callback: &PyObject,
190 bucket: String,
191 token: String,
192) -> PyResult<String> {
193 Python::with_gil(|py| {
194 let s = callback.call1(py, (token, bucket))?;
195 s.extract::<String>(py)
196 })
197}
198
#[cfg(test)]
mod tests {
    use std::time::{SystemTime, UNIX_EPOCH};

    use super::*;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Current wall-clock time in Unix seconds.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds a mock response with a raw JSON body and the given status code.
    fn make_response(json_body: &str, status_code: u16) -> ResponseTemplate {
        ResponseTemplate::new(status_code).set_body_raw(json_body, "application/json")
    }

    #[tokio::test]
    async fn test_get_bearer_success() {
        let mock_server = MockServer::start().await;

        let response_body = r#"{
            "access_token": "mock_access_token",
            "expires_in": 3600,
            "expiration": 9999999999
        }"#;

        Mock::given(method("POST"))
            .and(path("/identity/token"))
            .respond_with(make_response(response_body, 200))
            .mount(&mock_server)
            .await;

        let result = get_bearer_with_url("mock_api_key".to_string(), &mock_server.uri()).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "mock_access_token");
    }

    #[tokio::test]
    async fn test_get_bearer_failure() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/identity/token"))
            .respond_with(ResponseTemplate::new(400).set_body_string("Invalid API key"))
            .mount(&mock_server)
            .await;

        let result = get_bearer_with_url("invalid_api_key".to_string(), &mock_server.uri()).await;
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err().to_string(),
            "Failed to get token: Invalid API key"
        );
    }

    #[tokio::test]
    async fn test_get_bearer_invalid_json() {
        let mock_server = MockServer::start().await;

        let invalid_json = r#"{
            "invalid_field": "value"
        }"#;

        Mock::given(method("POST"))
            .and(path("/identity/token"))
            .respond_with(make_response(invalid_json, 200))
            .mount(&mock_server)
            .await;

        let result = get_bearer_with_url("mock_api_key".to_string(), &mock_server.uri()).await;
        assert!(result.is_err());

        let err_message = result.unwrap_err().to_string();
        assert!(
            err_message.contains("missing field `access_token`")
                || err_message.contains("error decoding response body"),
            "Unexpected error message: {}",
            err_message
        );
    }

    /// Test double that mirrors `get_bearer`'s request/response handling against
    /// `{base_url}/identity/token` and returns only the access token.
    async fn get_bearer_with_url(
        api_key: String,
        base_url: &str,
    ) -> Result<String, Box<dyn std::error::Error>> {
        let client = reqwest::Client::new();
        let params = [
            ("grant_type", "urn:ibm:params:oauth:grant-type:apikey"),
            ("apikey", api_key.as_str()),
        ];
        let resp = client
            .post(format!("{}/identity/token", base_url))
            .header("Content-Type", "application/x-www-form-urlencoded")
            .form(&params)
            .send()
            .await?;

        if resp.status().is_success() {
            let iam_response: IamResponse = resp.json().await?;
            Ok(iam_response.access_token)
        } else {
            let err_text = resp.text().await?;
            Err(format!("Failed to get token: {}", err_text).into())
        }
    }

    #[tokio::test]
    async fn secrets_cache_hit_returns_cached_value() {
        let cache = SecretsCache::new();
        let key = "test".to_string();

        cache.insert(key.clone(), "cached_token".to_string(), unix_now() + 3600);

        let fetcher = || async { panic!("Should not be called on cache hit") };

        let result = cache.get(&key, fetcher).await;
        assert_eq!(result, Some("cached_token".to_string()));
    }

    #[tokio::test]
    async fn secrets_cache_expired_renews_token() {
        let cache = SecretsCache::new();
        let key = "test2".to_string();
        // Expiring "now" falls inside the expiry margin, so the entry counts as expired.
        cache.insert(key.clone(), "old_token".to_string(), unix_now());

        let fetcher = move || async {
            Ok(IamResponse {
                access_token: "new_token".into(),
                expires_in: 3600,
                expiration: unix_now() + 7200,
            })
        };

        let result = cache.get(&key, fetcher).await;
        assert_eq!(result, Some("new_token".to_string()));
    }

    #[tokio::test]
    async fn secrets_cache_invalidate_works() {
        let cache = SecretsCache::new();
        let key = "test3".to_string();
        cache.insert(key.clone(), "token".to_string(), unix_now() + 3600);

        cache.invalidate(&key);

        let fetcher = move || async {
            Ok(IamResponse {
                access_token: "fresh_token".into(),
                expires_in: 3600,
                expiration: unix_now() + 3600,
            })
        };
        let result = cache.get(&key, fetcher).await;
        assert_eq!(result, Some("fresh_token".to_string()));
    }
}