Skip to main content

object_storage_proxy/
lib.rs

1//! # object-storage-proxy
2//!
3//! A fast, in-process reverse proxy for **AWS S3**, **IBM Cloud Object Storage (COS)** and other S3-compatible object storage services,
4//! with a Python interface for custom authentication and credential management.
5//!
6//! The proxy is built on top of [Pingora](https://github.com/cloudflare/pingora) and exposed
7//! to Python via [PyO3](https://pyo3.rs). It handles:
8//!
9//! * **AWS Signature Version 4** re-signing — incoming requests are validated and
10//!   then re-signed with backend credentials before being forwarded.
11//! * **Presigned URL enforcement** — optional per-URL usage limits prevent replay abuse.
12//! * **IBM IAM bearer-token exchange** — API keys are automatically exchanged for
13//!   short-lived IAM tokens and cached.
14//! * **Pluggable Python callbacks** — supply an async validator and/or a credential
15//!   fetcher callable from Python to integrate with any auth backend.
16//!
17//! ## Quick start (Python)
18//!
19//! ```python
20//! from object_storage_proxy import ProxyServerConfig, start_server
21//!
22//! config = ProxyServerConfig(
23//!     cos_map={
24//!         "my-bucket": {
25//!             "host": "s3.eu-west-3.amazonaws.com",
26//!             "port": 443,
27//!             "access_key": "AKIAIOSFODNN7EXAMPLE",
28//!             "secret_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY",
29//!             "region": "eu-west-3",
30//!         }
31//!     },
32//!     http_port=6190,
33//! )
34//! start_server(config)
35//! ```
36
37#![warn(clippy::all)]
38use async_trait::async_trait;
39use bytes::{Bytes, BytesMut};
40use credentials::signer::{
41    self, resign_streaming_request, signature_is_valid_for_presigned,
42    signature_is_valid_for_request,
43};
44use dashmap::DashMap;
45use dotenv::dotenv;
46use http::uri::Authority;
47use http::{Method, StatusCode, Uri};
48use parsers::cos_map::{CosMapItem, parse_cos_map};
49use parsers::keystore::parse_hmac_list;
50use pingora::Result;
51use pingora::http::ResponseHeader;
52use pingora::proxy::{ProxyHttp, Session};
53use pingora::server::Server;
54use pingora::upstreams::peer::HttpPeer;
55use pyo3::prelude::*;
56use pyo3::types::{PyModule, PyModuleMethods};
57use pyo3::{Bound, PyResult, Python, pyclass, pyfunction, pymodule, wrap_pyfunction};
58use std::sync::{
59    Arc,
60    atomic::{AtomicBool, AtomicUsize, Ordering},
61};
62
63// use utils::functions::inspect_callable_signature;
64
65use std::collections::HashMap;
66use std::fmt::Debug;
67
68use std::time::Duration;
69use tokio::sync::RwLock;
70use tracing::{debug, error, info, warn};
71use tracing_subscriber::EnvFilter;
72use tracing_subscriber::fmt::time::ChronoLocal;
73
74pub mod parsers;
75use parsers::credentials::{parse_presigned_params, parse_token_from_header};
76use parsers::path::{parse_path, parse_query};
77
78pub mod credentials;
79use credentials::{
80    secrets_proxy::{SecretsCache, get_bearer, get_credential_for_bucket},
81    signer::sign_request,
82};
83
84pub mod utils;
85use utils::banner::print_banner;
86use utils::response::write_error_response_with_header;
87use utils::validator::{AuthCache, validate_request};
88
/// Server name the proxy reports in responses it builds itself
/// (e.g. the `Server` header of the ListBuckets reply).
const DEFAULT_SERVER_NAME: &str = "<osp⚡>";

/// Debug counter bumped by `upstream_peer` while [`REQ_COUNTER_ENABLED`] is set.
static REQ_COUNTER: AtomicUsize = AtomicUsize::new(0);

/// Switch for [`REQ_COUNTER`]; counting is off until something flips this.
static REQ_COUNTER_ENABLED: AtomicBool = AtomicBool::new(false);
92
93/// Thread-safe hit counter for presigned URLs.
94///
95/// Tracks how many times each presigned URL has been used so that a configurable
96/// maximum can be enforced.  Regular (re-signed) requests are **not** tracked —
97/// the aws-cli issues parallel range-GET sub-requests for the same object, which
98/// would exhaust a small limit instantly.
99///
100/// Internally backed by a [`DashMap`] so that concurrent access from multiple
101/// Pingora worker threads never requires a global lock.
102#[derive(Clone)]
103pub struct UrlTracker {
104    /// Per-URL hit counters.
105    pub counts: Arc<DashMap<String, usize>>,
106}
107
108impl Default for UrlTracker {
109    fn default() -> Self {
110        Self::new()
111    }
112}
113
114impl UrlTracker {
115    /// Create a new, empty tracker.
116    pub fn new() -> Self {
117        UrlTracker {
118            counts: Arc::new(DashMap::new()),
119        }
120    }
121
122    /// Increment the hit counter for `url` by one.
123    pub fn track(&self, url: &str) {
124        let mut entry = self.counts.entry(url.to_string()).or_insert(0);
125        *entry += 1;
126        debug!(url, count = *entry, "tracking presigned URL");
127    }
128
129    /// Return the current hit count for `url`, or `None` if it has never been tracked.
130    pub fn get(&self, url: &str) -> Option<usize> {
131        self.counts.get(url).map(|v| *v)
132    }
133
134    /// Return a snapshot of all tracked URLs and their counts.
135    pub fn get_all(&self) -> Vec<(String, usize)> {
136        self.counts
137            .iter()
138            .map(|e| (e.key().clone(), *e.value()))
139            .collect()
140    }
141}
142
143/// Configuration object for :pyfunc:`object_storage_proxy.start_server`.
144///
145/// Parameters
146/// ----------
147/// cos_map:
148///    A dictionary mapping bucket names to their respective COS configuration.
149///   Each entry should contain the following
150///   keys:
151///   - host: The COS endpoint (e.g., "s3.eu-de.cloud-object-storage.appdomain.cloud")
152///   - port: The port number (e.g., 443)
153///   - api_key/apikey: The API key for the bucket (optional)
154///   - ttl/time-to-live: The time-to-live for the API key in seconds (optional)
155///
156/// bucket_creds_fetcher:
157///     Optional Python async callable that fetches the API key for a bucket.
158///     The callable should accept a single argument, the bucket name.
159///     It should return a string containing the API key.
160/// http_port:
161///     The HTTP port to listen on.
162/// https_port:
163///     The HTTPS port to listen on.
164/// validator:
165///     Optional Python async callable that validates the request.
166///     The callable should accept two arguments, the token and the bucket name.
167///     It should return a boolean indicating whether the request is valid.
168/// threads:
169///     Optional number of threads to use for the server.
170///     If not specified, the server will use a single thread.
171///
172#[pyclass]
173#[pyo3(name = "ProxyServerConfig")]
174#[derive(Debug)]
175pub struct ProxyServerConfig {
176    #[pyo3(get, set)]
177    pub bucket_creds_fetcher: Option<Py<PyAny>>,
178
179    #[pyo3(get, set)]
180    pub cos_map: PyObject,
181
182    #[pyo3(get, set)]
183    pub http_port: Option<u16>,
184
185    #[pyo3(get, set)]
186    pub https_port: Option<u16>,
187
188    #[pyo3(get, set)]
189    pub validator: Option<Py<PyAny>>,
190
191    #[pyo3(get, set)]
192    pub threads: Option<usize>,
193
194    #[pyo3(get, set)]
195    pub verify: Option<bool>,
196
197    #[pyo3(get, set)]
198    pub hmac_keystore: PyObject,
199
200    #[pyo3(get, set)]
201    pub skip_signature_validation: Option<bool>,
202
203    #[pyo3(get, set)]
204    pub hmac_fetcher: Option<Py<PyAny>>,
205
206    #[pyo3(get, set)]
207    pub max_presign_url_usage_attempts: Option<usize>,
208
209    #[pyo3(get, set)]
210    pub server_name: String,
211
212    /// Maximum number of `(access_key, bucket, method)` entries held in the
213    /// in-process authorization cache.  Once the limit is reached the
214    /// least-recently-used entry is evicted automatically.
215    /// Defaults to [`AUTH_CACHE_DEFAULT_CAPACITY`] (1024).
216    #[pyo3(get, set)]
217    pub auth_cache_capacity: Option<usize>,
218
219    /// Port to expose the Prometheus `/metrics` scrape endpoint on.
220    ///
221    /// Only effective when the `metrics` Cargo feature is enabled.
222    /// When `None` (the default) no metrics endpoint is started.
223    #[pyo3(get, set)]
224    pub metrics_port: Option<u16>,
225}
226
227impl Default for ProxyServerConfig {
228    fn default() -> Self {
229        ProxyServerConfig {
230            cos_map: Python::with_gil(|py| py.None()),
231            bucket_creds_fetcher: None,
232            http_port: None,
233            https_port: None,
234            validator: None,
235            threads: Some(1),
236            verify: None,
237            hmac_keystore: Python::with_gil(|py| py.None()),
238            skip_signature_validation: Some(false),
239            hmac_fetcher: None,
240            max_presign_url_usage_attempts: Some(3),
241            server_name: "<osp⚡>".to_string(),
242            auth_cache_capacity: None,
243            metrics_port: None,
244        }
245    }
246}
247
248#[pymethods]
249impl ProxyServerConfig {
250    #[new]
251    #[pyo3(
252        signature = (
253            cos_map,
254            hmac_keystore = None,
255            bucket_creds_fetcher = None,
256            http_port = None,
257            https_port = None,
258            validator = None,
259            threads = Some(1),
260            verify = None,
261            skip_signature_validation = Some(false),
262            hmac_fetcher = None,
263            max_presign_url_usage_attempts = Some(3),
264            server_name = "<osp⚡>".to_string(),
265            auth_cache_capacity = None,
266            metrics_port = None,
267        )
268    )]
269    #[allow(clippy::too_many_arguments)]
270    pub fn new(
271        cos_map: PyObject,
272        hmac_keystore: Option<PyObject>,
273        bucket_creds_fetcher: Option<PyObject>,
274        http_port: Option<u16>,
275        https_port: Option<u16>,
276        validator: Option<PyObject>,
277        threads: Option<usize>,
278        verify: Option<bool>,
279        skip_signature_validation: Option<bool>,
280        hmac_fetcher: Option<PyObject>,
281        max_presign_url_usage_attempts: Option<usize>,
282        server_name: String,
283        auth_cache_capacity: Option<usize>,
284        metrics_port: Option<u16>,
285    ) -> Self {
286        ProxyServerConfig {
287            cos_map,
288            hmac_keystore: hmac_keystore.unwrap_or_else(|| Python::with_gil(|py| py.None())),
289            bucket_creds_fetcher,
290            http_port,
291            https_port,
292            validator,
293            threads,
294            verify,
295            skip_signature_validation,
296            hmac_fetcher,
297            max_presign_url_usage_attempts,
298            server_name,
299            auth_cache_capacity,
300            metrics_port,
301        }
302    }
303
304    fn __repr__(&self) -> PyResult<String> {
305        Ok(format!(
306            "ProxyServerConfig(http_port={}, https_port={}, threads={:?})",
307            self.http_port.unwrap_or(0),
308            self.https_port.unwrap_or(0),
309            self.threads
310        ))
311    }
312}
313
314/// The core Pingora proxy handler.
315///
316/// One instance is created per server and shared (via [`Arc`]) across all worker
317/// threads.  It implements [`ProxyHttp`] and drives the full request lifecycle:
318/// signature validation -> authorization -> credential injection -> upstream routing.
319pub struct MyProxy {
320    cos_endpoint: String,
321    cos_mapping: Arc<RwLock<HashMap<String, CosMapItem>>>,
322    hmac_keystore: Arc<RwLock<HashMap<String, String>>>,
323    secrets_cache: SecretsCache,
324    auth_cache: AuthCache,
325    validator: Option<PyObject>,
326    bucket_creds_fetcher: Option<PyObject>,
327    verify: Option<bool>,
328    skip_signature_validation: Option<bool>,
329    hmac_fetcher: Option<PyObject>,
330    tracker: UrlTracker,
331    max_presign_url_usage_attempts: Option<usize>,
332    #[allow(dead_code)]
333    server_name: String,
334    /// Cached result of `inspect.signature` on the validator callable:
335    /// `true`  = validator accepts a third `request: dict` argument.
336    /// `false` = validator only takes `(token, bucket)`.
337    /// `None`  = no validator configured.
338    validator_takes_request: Option<bool>,
339}
340
341/// Per-request context threaded through the Pingora middleware chain.
342///
343/// A fresh `MyCtx` is created by [`MyProxy::new_ctx`] for every incoming
344/// connection and is discarded when the request completes.
345pub struct MyCtx {
346    cos_mapping: Arc<RwLock<HashMap<String, CosMapItem>>>,
347    hmac_keystore: Arc<RwLock<HashMap<String, String>>>,
348    secrets_cache: SecretsCache,
349    auth_cache: AuthCache,
350    validator: Option<PyObject>,
351    bucket_creds_fetcher: Option<PyObject>,
352    hmac_fetcher: Option<PyObject>,
353    is_presigned: Option<bool>,
354    stream_state: Option<signer::StreamingState>,
355    /// Bucket name parsed from the request path in `request_filter`, reused by
356    /// later stages to avoid redundant `parse_path` calls and map lock acquires.
357    cached_bucket: Option<String>,
358    /// CosMapItem resolved in `request_filter` and reused by `upstream_peer`
359    /// to avoid a second `cos_mapping` RwLock read on every request.
360    /// TODO(perf-2): done — see upstream_peer
361    cached_bucket_config: Option<CosMapItem>,
362}
363
364// impl MyCtx {
365//     fn streaming(&mut self) -> &mut signer::StreamingState {
366//         self.stream_state.as_mut().expect("stream_state not initialised")
367//     }
368// }
369
370#[async_trait]
371impl ProxyHttp for MyProxy {
372    type CTX = MyCtx;
373    fn new_ctx(&self) -> Self::CTX {
374        // Acquire the GIL once to clone all three optional Python callables
375        // rather than paying the GIL acquisition cost up to three times per
376        // new connection.  TODO(perf-5): done
377        let (validator, bucket_creds_fetcher, hmac_fetcher) = Python::with_gil(|py| {
378            (
379                self.validator.as_ref().map(|v| v.clone_ref(py)),
380                self.bucket_creds_fetcher.as_ref().map(|v| v.clone_ref(py)),
381                self.hmac_fetcher.as_ref().map(|v| v.clone_ref(py)),
382            )
383        });
384        MyCtx {
385            cos_mapping: Arc::clone(&self.cos_mapping),
386            hmac_keystore: Arc::clone(&self.hmac_keystore),
387            secrets_cache: self.secrets_cache.clone(),
388            auth_cache: self.auth_cache.clone(),
389            validator,
390            bucket_creds_fetcher,
391            hmac_fetcher,
392            is_presigned: None,
393            stream_state: None,
394            cached_bucket: None,
395            cached_bucket_config: None,
396        }
397    }
398
399    async fn upstream_peer(
400        &self,
401        session: &mut Session,
402        ctx: &mut Self::CTX,
403    ) -> Result<Box<HttpPeer>> {
404        debug!("upstream_peer::start");
405        #[cfg(feature = "metrics")]
406        utils::metrics::ACTIVE_CONNECTIONS.inc();
407        if REQ_COUNTER_ENABLED.load(Ordering::Relaxed) {
408            let new_val = REQ_COUNTER.fetch_add(1, Ordering::Relaxed) + 1;
409            debug!("Request count: {}", new_val);
410        }
411
412        let hdr_bucket = ctx.cached_bucket.clone().unwrap_or_else(|| {
413            let path = session.req_header().uri.path();
414            parse_path(path)
415                .map(|(_, (b, _))| b.to_owned())
416                .unwrap_or_default()
417        });
418
419        // Use the config cached by request_filter; fall back to a fresh lock
420        // read only for the rare case where upstream_peer is called without a
421        // preceding request_filter (e.g. direct Pingora internal calls).
422        let bucket_config = if ctx.cached_bucket_config.is_some() {
423            ctx.cached_bucket_config.clone()
424        } else {
425            let map = ctx.cos_mapping.read().await;
426            map.get(&hdr_bucket).cloned()
427        };
428
429        let addressing_style = bucket_config
430            .as_ref()
431            .and_then(|c| c.addressing_style.as_deref())
432            .unwrap_or("virtual");
433
434        let endpoint = match &bucket_config {
435            Some(config) => {
436                if addressing_style == "path" {
437                    config.host.clone()
438                } else {
439                    format!("{}.{}", hdr_bucket, config.host)
440                }
441            }
442            None => format!("{}.{}", hdr_bucket, self.cos_endpoint),
443        };
444
445        let port = bucket_config.as_ref().map(|c| c.port).unwrap_or(443);
446
447        let addr = (endpoint.clone(), port);
448
449        let endpoint_is_tls = bucket_config.as_ref().and_then(|c| c.tls).unwrap_or(true);
450
451        debug!(endpoint_is_tls, endpoint, "resolved upstream peer");
452
453        let mut peer = Box::new(HttpPeer::new(addr, endpoint_is_tls, endpoint.clone()));
454        debug!(?peer, "upstream peer created");
455
456        // todo: make ths configurable
457
458        peer.options.max_h2_streams = 128;
459        peer.options.h2_ping_interval = Some(Duration::from_secs(30));
460
461        // peer.options.idle_timeout          = Some(Duration::from_secs(300));
462        // peer.options.connection_timeout    = Some(Duration::from_secs(30));
463        // peer.options.read_timeout          = Some(Duration::from_secs(300));
464        // peer.options.write_timeout         = Some(Duration::from_secs(300));
465
466        debug!("peer: {:#?}", &peer);
467
468        if let Some(verify) = self.verify {
469            info!("Verify peer (upstream) certificates disabled!");
470            peer.options.verify_cert = verify;
471            peer.options.verify_hostname = verify;
472        } else {
473            peer.options.verify_cert = true;
474        }
475
476        debug!("peer: {:#?}", &peer);
477
478        debug!("upstream_peer::end");
479        Ok(peer)
480    }
481
482    async fn logging(
483        &self,
484        _session: &mut Session,
485        _e: Option<&pingora::Error>,
486        ctx: &mut Self::CTX,
487    ) {
488        #[cfg(feature = "metrics")]
489        utils::metrics::ACTIVE_CONNECTIONS.dec();
490        let _ = ctx;
491    }
492
493    async fn request_filter(&self, session: &mut Session, ctx: &mut Self::CTX) -> Result<bool> {
494        debug!("request_filter::start");
495
496        // Tracking the request count for presigned URLs only.
497        // Regular (re-signed) requests must not be counted — aws-cli issues multiple
498        // parallel range-GET requests for the same object (multipart download), so
499        // counting every request would exhaust the limit almost immediately.
500        let url = session.req_header().uri.to_string();
501        let path = session.req_header().uri.path().to_string();
502        let is_presigned_url = session
503            .req_header()
504            .uri
505            .query()
506            .is_some_and(|q| q.contains("X-Amz-Signature"));
507        if is_presigned_url {
508            self.tracker.track(&url);
509        }
510        let tracked_count = self.tracker.get(&url).unwrap_or(0);
511        if is_presigned_url && tracked_count > self.max_presign_url_usage_attempts.unwrap_or(3) {
512            #[cfg(feature = "metrics")]
513            {
514                let bucket_label = session
515                    .req_header()
516                    .uri
517                    .path()
518                    .split('/')
519                    .nth(1)
520                    .unwrap_or("-");
521                utils::metrics::PRESIGNED_URL_REJECTED_TOTAL
522                    .with_label_values(&[bucket_label])
523                    .inc();
524            }
525            warn!(
526                url,
527                tracked_count,
528                max = self.max_presign_url_usage_attempts.unwrap_or(3),
529                "presigned URL usage limit exceeded, denying"
530            );
531            let msg = format!(
532                "URL ({}) has been tracked too many times: {} (max={}).  Access Denied!",
533                path,
534                tracked_count,
535                self.max_presign_url_usage_attempts.unwrap_or(3)
536            );
537
538            // let mut hdr = ResponseHeader::build(StatusCode::FORBIDDEN, Some(msg.len()))?;
539            // hdr.insert_header("content-type", "text/plain")?;
540            // hdr.insert_header("server", self.server_name.clone())?;
541            // hdr.insert_header("x-content-type-options", "nosniff")?;
542
543            // // Send it
544            // session.write_response_header(Box::new(hdr), false).await?;
545            // // session
546            // //     .write_response_body(Some(msg.into()), true)
547            // //     .await?;
548
549            // session.respond_error_with_body(403, msg.into()).await?;
550            write_error_response_with_header(session, StatusCode::FORBIDDEN, msg).await?;
551            return Ok(true);
552        }
553
554        debug!(summary = ?session.request_summary(), "request summary");
555        debug!(uri = ?session.req_header().uri, "incoming request URI");
556        debug!("request path: {}", session.req_header().uri.path());
557        debug!("request method: {}", session.req_header().method);
558
559        if session
560            .req_header()
561            .headers
562            .get("expect")
563            .map(|v| {
564                v.to_str()
565                    .unwrap_or("")
566                    .eq_ignore_ascii_case("100-continue")
567            })
568            .unwrap_or(false)
569        {
570            return Ok(false);
571        };
572
573        let path = session.req_header().uri.path().to_owned();
574
575        // ── ListBuckets short-circuit ────────────────────────────────────────────
576        // GET / has no bucket component; parse_path would error.  Return the list
577        // of buckets that are configured in the cos_mapping.
578        if path == "/" && session.req_header().method == Method::GET {
579            let bucket_names: Vec<String> = {
580                let map = ctx.cos_mapping.read().await;
581                let mut names: Vec<String> = map.keys().cloned().collect();
582                names.sort();
583                names
584            };
585            let entries: String = bucket_names
586                .iter()
587                .map(|n| {
588                    format!(
589                        "<Bucket><Name>{n}</Name>\
590<CreationDate>2000-01-01T00:00:00.000Z</CreationDate></Bucket>"
591                    )
592                })
593                .collect();
594            let body = format!(
595                "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\
596<ListAllMyBucketsResult xmlns=\"http://s3.amazonaws.com/doc/2006-03-01/\">\
597<Owner><ID>proxy</ID><DisplayName>proxy</DisplayName></Owner>\
598<Buckets>{entries}</Buckets>\
599</ListAllMyBucketsResult>"
600            );
601            let body_bytes = Bytes::copy_from_slice(body.as_bytes());
602            let mut hdr = ResponseHeader::build(StatusCode::OK, None)?;
603            hdr.insert_header("Content-Type", "application/xml")?;
604            hdr.insert_header("Content-Length", body_bytes.len().to_string())?;
605            hdr.insert_header("Server", DEFAULT_SERVER_NAME)?;
606            session.write_response_header(Box::new(hdr), false).await?;
607            session.write_response_body(Some(body_bytes), true).await?;
608            return Ok(true);
609        }
610
611        let parse_path_result = parse_path(&path);
612        if parse_path_result.is_err() {
613            error!("Failed to parse path: {:?}", parse_path_result);
614            return Err(pingora::Error::new_str("Failed to parse path"));
615        }
616
617        let (_, (bucket, _uri_path)) = parse_path_result.expect("checked above");
618
619        let hdr_bucket = bucket.to_owned();
620        ctx.cached_bucket = Some(hdr_bucket.clone());
621
622        #[cfg(feature = "metrics")]
623        {
624            let method_label = session.req_header().method.as_str();
625            utils::metrics::REQUESTS_TOTAL
626                .with_label_values(&[method_label, &hdr_bucket, "received"])
627                .inc();
628            if is_presigned_url {
629                utils::metrics::PRESIGNED_URL_HITS_TOTAL
630                    .with_label_values(&[&hdr_bucket])
631                    .inc();
632            }
633        }
634
635        let auth_header = session
636            .req_header()
637            .headers
638            .get("authorization")
639            .and_then(|h| h.to_str().ok())
640            .map(ToString::to_string)
641            .unwrap_or_default();
642
643        let (ttl, bucket_config_init) = {
644            let map = ctx.cos_mapping.read().await;
645            let cfg = map.get(bucket).cloned();
646            let ttl = cfg.as_ref().and_then(|c| c.ttl).unwrap_or(0);
647            (ttl, cfg)
648        };
649        // Cache the resolved config so upstream_peer can skip a second lock read.
650        ctx.cached_bucket_config = bucket_config_init.clone();
651        let mut access_key: String = String::new();
652
653        if auth_header.is_empty() {
654            if let Some(q) = session.req_header().uri.query()
655                && q.contains("X-Amz-Credential")
656            {
657                let (_, p) = parse_presigned_params(&format!("?{q}"))
658                    .map_err(|_| pingora::Error::new_str("Failed to parse presigned params"))?;
659                access_key = p.access_key.clone();
660            }
661        } else {
662            access_key = parse_token_from_header(&auth_header)
663                .map_err(|_| pingora::Error::new_str("Failed to parse access_key"))?
664                .1
665                .to_string();
666        }
667
668        let is_authorized = if let Some(py_cb) = &ctx.validator {
669            let is_multipart = session
670                .req_header()
671                .uri
672                .query()
673                .is_some_and(|q| q.contains("uploadId="));
674
675            debug!("checking signature");
676            if let Some(skip) = self.skip_signature_validation {
677                if skip || is_multipart {
678                    debug!("Skipping local signature check");
679                    // continue
680                } else {
681                    // presigned
682                    debug!("Checking presigned signature");
683                    let uri_q = session.req_header().uri.query().unwrap_or("");
684
685                    if auth_header.is_empty() && uri_q.contains("X-Amz-Signature") {
686                        ctx.is_presigned = Some(true);
687
688                        // ensure we have the secret_key in the keystore
689                        if !ctx.hmac_keystore.read().await.contains_key(&access_key) {
690                            debug!(
691                                "No key in keystore, trying to fetch via hmac_fetcher for ->{}<-",
692                                access_key
693                            );
694                            // fetch via hmac_fetcher exactly as you do below…
695                            if let Some(py_fetcher) = &ctx.hmac_fetcher {
696                                // call Python callback
697                                let cb = py_fetcher;
698                                let secret: PyResult<String> = Python::with_gil(|py| {
699                                    cb.call1(py, (&access_key,)).and_then(|r| r.extract(py))
700                                });
701                                debug!("Got secret: {:#?}", secret);
702                                match secret {
703                                    Ok(secret_key) => {
704                                        debug!("got key and inserting into keystore");
705                                        ctx.hmac_keystore
706                                            .write()
707                                            .await
708                                            .insert(access_key.clone(), secret_key);
709                                    }
710                                    Err(_) => {
711                                        // no key -> unauthorized
712                                        write_error_response_with_header(
713                                            session,
714                                            StatusCode::UNAUTHORIZED,
715                                            "No key found for presigned URL".to_string(),
716                                        )
717                                        .await?;
718                                        // session.respond_error(401).await?;
719                                        return Ok(true);
720                                    }
721                                }
722                            } else {
723                                // session.respond_error(401).await?;
724                                write_error_response_with_header(
725                                    session,
726                                    StatusCode::UNAUTHORIZED,
727                                    "No key found for presigned URL".to_string(),
728                                )
729                                .await?;
730                                return Ok(true);
731                            }
732                        }
733                        debug!("now checking if the signature is valid for presigned...");
734                        let sk = ctx
735                            .hmac_keystore
736                            .read()
737                            .await
738                            .get(&access_key)
739                            .expect("key was just inserted")
740                            .clone();
741                        debug!("got secret {} from keystore", sk);
742                        debug!("RAW_PATH       = {}", &session.req_header().uri);
743                        debug!(
744                            "RAW_HOST_HDR   = {:?}",
745                            &session.req_header().headers.get("host")
746                        );
747                        let presigned_result = signature_is_valid_for_presigned(session, &sk)
748                            .await
749                            .map_err(|e| e.to_string());
750                        let ok = match presigned_result {
751                            Ok(b) => b,
752                            Err(msg) => {
753                                error!("presigned-URL validation error: {msg}");
754                                if msg.contains("expired") {
755                                    write_error_response_with_header(
756                                        session,
757                                        StatusCode::FORBIDDEN,
758                                        format!(
759                                            "Presigned URL has expired: {}",
760                                            session.req_header().uri.path()
761                                        ),
762                                    )
763                                    .await?;
764                                    return Ok(true);
765                                }
766                                return Err(pingora::Error::new_str("Failed to check signature"));
767                            }
768                        };
769                        debug!("is signature valid?: {}", ok);
770                        if !ok {
771                            let msg = format!(
772                                "Signature invalid for presigned URL: {}",
773                                &session.req_header().uri.path()
774                            );
775                            session.respond_error_with_body(401, msg.into()).await?;
776                            return Ok(true);
777                        }
778                    } else {
779                        debug!("processing a regular request");
780
781                        let has_key = {
782                            let map = ctx.hmac_keystore.read().await;
783                            map.contains_key(&access_key)
784                        };
785                        if !has_key {
786                            if let Some(py_fetcher) = &ctx.hmac_fetcher {
787                                // call Python callback
788                                let cb = py_fetcher;
789                                let secret: PyResult<String> = Python::with_gil(|py| {
790                                    cb.call1(py, (&access_key,)).and_then(|r| r.extract(py))
791                                });
792                                match secret {
793                                    Ok(secret_key) => {
794                                        ctx.hmac_keystore
795                                            .write()
796                                            .await
797                                            .insert(access_key.clone(), secret_key);
798                                    }
799                                    Err(_) => {
800                                        // no key -> unauthorized
801                                        // session.respond_error(401).await?;
802                                        write_error_response_with_header(
803                                            session,
804                                            StatusCode::UNAUTHORIZED,
805                                            "No key found for request".to_string(),
806                                        )
807                                        .await?;
808                                        return Ok(true);
809                                    }
810                                }
811                            } else {
812                                // session.respond_error(401).await?;
813                                write_error_response_with_header(
814                                    session,
815                                    StatusCode::UNAUTHORIZED,
816                                    "No key found for request".to_string(),
817                                )
818                                .await?;
819                                return Ok(true);
820                            }
821                        }
822                        let secret_key = {
823                            let map = ctx.hmac_keystore.read().await;
824                            map.get(&access_key).cloned()
825                        };
826
827                        debug!("checking signature");
828                        let sig_ok = match signature_is_valid_for_request(
829                            &auth_header,
830                            session,
831                            &secret_key.expect("key was just inserted"),
832                        )
833                        .await
834                        {
835                            Ok(true) => true,
836                            Ok(false) => {
837                                debug!("Signature invalid");
838                                false
839                            }
840                            Err(err) => {
841                                error!("Signature check error: {}", err);
842                                false
843                            }
844                        };
845
846                        // if signature failed, skip further validation
847                        if !sig_ok {
848                            //  session.respond_error(401).await?;
849                            write_error_response_with_header(
850                                session,
851                                StatusCode::UNAUTHORIZED,
852                                "Signature invalid".to_string(),
853                            )
854                            .await?;
855                            return Ok(true);
856                        }
857                    }
858                }
859            }
860            debug!("Signature check passed, continuing now onto the bespoke validation");
861            // Build the query dict here — deferred so requests without a validator
862            // pay no parsing cost at all.
863            let request_query = session.req_header().uri.query().unwrap_or("");
864            let (_, mut query_dict) = parse_query(request_query).map_err(|e| {
865                error!("Failed to parse query: {:?}", e);
866                pingora::Error::new_str("Failed to parse query")
867            })?;
868            query_dict.insert(
869                "method".to_string(),
870                session.req_header().method.to_string(),
871            );
872            query_dict.insert(
873                "path".to_string(),
874                session.req_header().uri.path().to_string(),
875            );
876            query_dict.insert(
877                "source".to_string(),
878                session
879                    .req_header()
880                    .headers
881                    .get("x-forwarded-for")
882                    .and_then(|h| h.to_str().ok())
883                    .unwrap_or_default()
884                    .to_string(),
885            );
886            debug!("Parsed query: {:#?}", query_dict);
887            // Cache key: access_key + bucket + HTTP method only.
888            // Volatile query params (uploadId, X-Amz-Date, etc.) must NOT be
889            // included — they differ on every request and would make the cache
890            // useless.
891            let method_str = session.req_header().method.as_str();
892            let cache_key = format!("{}:{}:{}", &access_key, bucket, method_str);
893            debug!("Cache key: {}", cache_key);
894
895            // Default 300-second TTL so the cache is always effective.
896            // ttl=0 in the bucket config means "use default", not "disable caching".
897            // Set ttl to u64::MAX in the bucket config to opt out of expiry.
898            let effective_ttl = Duration::from_secs(if ttl == 0 { 300 } else { ttl });
899
900            let bucket_clone = bucket.to_string();
901            let callback_clone: PyObject = Python::with_gil(|py| py_cb.clone_ref(py));
902
903            let move_access_key = access_key.clone();
904            let req = query_dict.clone();
905            let takes_request = self.validator_takes_request.unwrap_or(false);
906
907            ctx.auth_cache
908                .get_or_validate(&cache_key, effective_ttl, move || {
909                    let tk = move_access_key.clone();
910                    let bu = bucket_clone.clone();
911                    let cb = Python::with_gil(|py| callback_clone.clone_ref(py));
912                    {
913                        let req_value = req.clone();
914                        async move {
915                            validate_request(&tk, &bu, &req_value, cb, takes_request)
916                                .await
917                                .map_err(|_| pingora::Error::new_str("Validator error"))
918                        }
919                    }
920                })
921                .await?
922        } else {
923            true
924        };
925
926        if !is_authorized {
927            warn!("Access denied for bucket: {}.  End of request.", bucket);
928            // session.respond_error(401).await?;
929            write_error_response_with_header(
930                session,
931                StatusCode::UNAUTHORIZED,
932                format!("Access denied for bucket: {}", bucket),
933            )
934            .await?;
935            return Ok(true);
936        }
937
938        let bucket_config = bucket_config_init;
939
940        debug!("Access key: {}", &access_key);
941
942        // we have to check for some available credentials here to be able to return unauthorized already if not
943        match bucket_config.clone() {
944            Some(mut config) => {
945                let fetcher_opt = ctx.bucket_creds_fetcher.as_ref().map(|py_cb| {
946                    // clone the PyObject so the async block is 'static
947                    let cb = Python::with_gil(|py| py_cb.clone_ref(py));
948                    move |bucket: String| async move {
949                        get_credential_for_bucket(&cb, bucket, access_key)
950                            .await
951                            .map_err(|e| e.into()) // Convert PyErr -> Box<dyn Error>
952                    }
953                });
954
955                config
956                    .ensure_credentials(&hdr_bucket, fetcher_opt)
957                    .await
958                    .map_err(|e| {
959                        error!("Credential check failed for {hdr_bucket}: {e}");
960                        pingora::Error::new_str("Credential check failed")
961                    })?;
962
963                ctx.cos_mapping
964                    .write()
965                    .await
966                    .insert(hdr_bucket.clone(), config);
967            }
968            None => {
969                warn!("No configuration for bucket '{hdr_bucket}'; returning 404");
970                // Build an S3-style NoSuchBucket error.  HEAD requests must not
971                // include a body per HTTP spec, so we only write one for others.
972                let mut hdr = ResponseHeader::build(StatusCode::NOT_FOUND, None)?;
973                hdr.insert_header("Server", DEFAULT_SERVER_NAME)?;
974                if session.req_header().method == Method::HEAD {
975                    session.write_response_header(Box::new(hdr), true).await?;
976                } else {
977                    let xml = format!(
978                        "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\
979<Error><Code>NoSuchBucket</Code>\
980<Message>The specified bucket does not exist</Message>\
981<BucketName>{hdr_bucket}</BucketName></Error>"
982                    );
983                    let xml_bytes = Bytes::copy_from_slice(xml.as_bytes());
984                    hdr.insert_header("Content-Type", "application/xml")?;
985                    hdr.insert_header("Content-Length", xml_bytes.len().to_string())?;
986                    session.write_response_header(Box::new(hdr), false).await?;
987                    session.write_response_body(Some(xml_bytes), true).await?;
988                }
989                return Ok(true);
990            }
991        }
992        debug!(
993            "request_filter::Credentials checked for bucket: {}. End of function.",
994            hdr_bucket
995        );
996        debug!("request_filter::end");
997        Ok(false)
998    }
999
1000    async fn upstream_request_filter(
1001        &self,
1002        _session: &mut Session,
1003        upstream_request: &mut pingora::http::RequestHeader,
1004        ctx: &mut Self::CTX,
1005    ) -> Result<()> {
1006        if let Some(presigned) = ctx.is_presigned
1007            && presigned
1008        {
1009            debug!("upstream_request_filter::presigned");
1010            let cleaned_q = upstream_request
1011                .uri
1012                .query()
1013                .unwrap_or("")
1014                .split('&')
1015                .filter(|kv| !kv.starts_with("X-Amz-"))
1016                .collect::<Vec<_>>()
1017                .join("&");
1018
1019            let _ = upstream_request.remove_header("authorization");
1020
1021            let new_path_and_query = if cleaned_q.is_empty() {
1022                upstream_request.uri.path().to_owned()
1023            } else {
1024                format!("{}?{}", upstream_request.uri.path(), cleaned_q)
1025            };
1026
1027            upstream_request.set_uri(
1028                new_path_and_query
1029                    .try_into()
1030                    .map_err(|_| pingora::Error::new_str("invalid URI after query rewrite"))?,
1031            );
1032        }
1033
1034        let _ = upstream_request.remove_header("accept-encoding");
1035
1036        debug!("upstream_request_filter::start");
1037
1038        let (_, (bucket, my_updated_url)) = parse_path(upstream_request.uri.path())
1039            .map_err(|_| pingora::Error::new_str("failed to parse upstream request path"))?;
1040
1041        debug!(my_updated_url, "parsed upstream path");
1042
1043        let hdr_bucket = bucket.to_string();
1044
1045        let my_query = match upstream_request.uri.query() {
1046            Some(q) if !q.is_empty() => format!("?{}", q),
1047            _ => String::new(),
1048        };
1049
1050        let bucket_config = {
1051            let map = ctx.cos_mapping.read().await;
1052            map.get(&hdr_bucket).cloned()
1053        };
1054
1055        let addressing_style = bucket_config
1056            .as_ref()
1057            .and_then(|c| c.addressing_style.as_deref())
1058            .unwrap_or("virtual");
1059
1060        let this_url = match addressing_style {
1061            "virtual" => my_updated_url,
1062            _ => {
1063                // For bucket-root requests, my_updated_url is "/" which would
1064                // produce "/bucket/" (with trailing slash).  Bucket-level S3
1065                // operations (ListObjects, ListMultipartUploads, …) must be
1066                // addressed as "/bucket" — without the trailing slash.
1067                let u_url = if my_updated_url == "/" {
1068                    format!("/{}", bucket)
1069                } else {
1070                    format!("/{}{}", bucket, my_updated_url)
1071                };
1072                debug!(u_url, "using path addressing style");
1073                &u_url.clone()
1074            }
1075        };
1076
1077        let endpoint = match &bucket_config {
1078            Some(cfg) => {
1079                let this_host = match addressing_style {
1080                    "path" => cfg.host.clone(),
1081                    _ => format!("{}.{}", bucket, cfg.host),
1082                };
1083                if cfg.port == 443 {
1084                    this_host
1085                } else {
1086                    format!("{}:{}", this_host, cfg.port)
1087                }
1088            }
1089            None => format!("{}.{}", bucket, self.cos_endpoint),
1090        };
1091
1092        debug!("endpoint: {}.", &endpoint);
1093
1094        let authority = Authority::try_from(endpoint.as_str())
1095            .map_err(|_| pingora::Error::new_str("invalid upstream authority"))?;
1096        // if addressing_style == "virtual" {
1097
1098        let new_uri = Uri::builder()
1099            .scheme("https")
1100            .authority(authority.clone())
1101            .path_and_query(this_url.to_owned() + &my_query)
1102            .build()
1103            .expect("should build a valid URI");
1104
1105        upstream_request.set_uri(new_uri.clone());
1106        // }
1107        upstream_request.insert_header("host", authority.as_str())?;
1108
1109        let (maybe_hmac, maybe_api_key) = match &bucket_config {
1110            Some(cfg) => (cfg.has_hmac(), cfg.api_key.clone()),
1111            None => (false, None),
1112        };
1113
1114        let allowed = [
1115            "host",
1116            "content-length",
1117            "content-type",
1118            "content-md5",
1119            "x-amz-date",
1120            "x-amz-content-sha256",
1121            "x-amz-security-token",
1122            "transfer-encoding",
1123            "content-encoding",
1124            "x-amz-decoded-content-length",
1125            "x-amz-trailer",
1126            "x-amz-sdk-checksum-algorithm",
1127            // CopyObject headers
1128            "x-amz-copy-source",
1129            "x-amz-metadata-directive",
1130            "x-amz-copy-source-if-match",
1131            "x-amz-copy-source-if-none-match",
1132            "x-amz-copy-source-if-modified-since",
1133            "x-amz-copy-source-if-unmodified-since",
1134            // UploadPartCopy byte-range
1135            "x-amz-copy-source-range",
1136            // Conditional GET/PUT (If-Match, If-None-Match, etc.)
1137            "if-match",
1138            "if-none-match",
1139            "if-modified-since",
1140            "if-unmodified-since",
1141            // User-visible metadata that Garage stores and returns verbatim
1142            "cache-control",
1143            "content-disposition",
1144            // Inline object tagging on PutObject
1145            "x-amz-tagging",
1146            // TaggingDirective on CopyObject (COPY or REPLACE)
1147            "x-amz-tagging-directive",
1148            "range",
1149            "expect",
1150        ];
1151
1152        let to_remove: Vec<String> = upstream_request
1153            .headers
1154            .iter()
1155            .filter_map(|(name, _)| {
1156                let n = name.as_str();
1157                let keep = allowed.contains(&n)
1158                    || n.starts_with("x-amz-checksum-")
1159                    || n.starts_with("x-amz-meta-");
1160                if keep { None } else { Some(n.to_owned()) }
1161            })
1162            .collect();
1163
1164        for name in to_remove {
1165            let _ = upstream_request.remove_header(&name);
1166        }
1167
1168        if maybe_hmac {
1169            debug!("HMAC: Signing request for bucket: {}", hdr_bucket);
1170
1171            let streaming = {
1172                upstream_request
1173                    .headers
1174                    .get("x-amz-content-sha256")
1175                    .map(|v| v.as_bytes().starts_with(b"STREAMING-"))
1176                    .unwrap_or(false)
1177            };
1178
1179            if streaming {
1180                let streaming_header = upstream_request
1181                    .headers
1182                    .get("x-amz-content-sha256")
1183                    .and_then(|v| v.to_str().ok())
1184                    .unwrap_or_default();
1185
1186                debug!(streaming_header, "streaming upload detected");
1187
1188                let cfg = bucket_config.as_ref().ok_or_else(|| {
1189                    pingora::Error::new_str("no bucket config for streaming upload")
1190                })?;
1191                let access_key = cfg.access_key.as_deref().unwrap_or_default().to_string();
1192                let secret_key = cfg.secret_key.as_deref().unwrap_or_default().to_string();
1193                let region = cfg.region.as_deref().unwrap_or_default().to_string();
1194
1195                // let decoded_len = upstream_request
1196                //     .headers
1197                //     .get("x-amz-decoded-content-length")
1198                //     .and_then(|v| v.to_str().ok())
1199                //     .unwrap_or("0")
1200                //     .to_owned();
1201
1202                // remove the original streaming headers we cannot forward.
1203                // upstream_request.remove_header("x-amz-decoded-content-length");
1204
1205                //  stream-chunk.
1206                debug!(headers = ?upstream_request.headers, "upstream request headers before streaming rewrite");
1207                upstream_request.remove_header("content-length");
1208                upstream_request.remove_header("content-md5");
1209                upstream_request.insert_header("transfer-encoding", "chunked")?;
1210                // upstream_request.insert_header("x-amz-decoded-content-length", decoded_len)?;
1211                upstream_request.set_send_end_stream(false);
1212
1213                // produce *seed* signature and signing key that will be reused
1214                //    for every DATA frame in the forthcoming request_body_filter.
1215                let ts = chrono::Utc::now();
1216                resign_streaming_request(upstream_request, &region, &access_key, &secret_key, ts)
1217                    .map_err(|e| {
1218                    error!("Failed to sign request: {e}");
1219                    pingora::Error::new_str("Failed to sign request")
1220                })?;
1221
1222                let seed_sig = upstream_request
1223                    .headers
1224                    .get("authorization")
1225                    .and_then(|v| v.to_str().ok())
1226                    .and_then(|v| v.split("Signature=").nth(1))
1227                    .expect("seed signature missing")
1228                    .to_owned();
1229
1230                // stash everything the body filter will need.
1231                ctx.stream_state = Some(signer::StreamingState::new(
1232                    region.to_string(),
1233                    access_key.to_string(),
1234                    secret_key.to_string(),
1235                    ts,
1236                    seed_sig,
1237                ));
1238            } else {
1239                sign_request(
1240                    upstream_request,
1241                    bucket_config
1242                        .as_ref()
1243                        .ok_or_else(|| pingora::Error::new_str("no bucket config for signing"))?,
1244                )
1245                .await
1246                .map_err(|e| {
1247                    error!("Failed to sign request for {}: {e}", hdr_bucket);
1248                    pingora::Error::new_str("Failed to sign request")
1249                })?;
1250            }
1251
1252            debug!("Request signed for bucket: {}", hdr_bucket);
1253            debug!("{:#?}", &upstream_request.headers);
1254        } else {
1255            debug!("Using API key for bucket: {}", hdr_bucket);
1256            let api_key = match maybe_api_key {
1257                Some(key) => key,
1258                None => {
1259                    // should be impossible because request_filter already
1260                    // called ensure_credentials, but double‑check anyway
1261                    error!("No API key for bucket {hdr_bucket}");
1262                    return Err(pingora::Error::new_str("No API key configured for bucket"));
1263                }
1264            };
1265
1266            // closure captured by SecretsCache
1267            let bearer_fetcher = {
1268                let api_key = api_key.clone();
1269                move || get_bearer(api_key.clone())
1270            };
1271
1272            let bearer_token = ctx
1273                .secrets_cache
1274                .get(&hdr_bucket, bearer_fetcher)
1275                .await
1276                .ok_or_else(|| pingora::Error::new_str("Failed to obtain bearer token"))?;
1277
1278            upstream_request.insert_header("Authorization", format!("Bearer {bearer_token}"))?;
1279        }
1280
1281        // debug!("Sending request to upstream: {}", &new_uri);
1282
1283        debug!("Request sent to upstream.");
1284        debug!("upstream_request_filter::end");
1285
1286        Ok(())
1287    }
1288
1289    async fn response_filter(
1290        &self,
1291        #[cfg_attr(not(feature = "metrics"), allow(unused_variables))] session: &mut Session,
1292        resp: &mut ResponseHeader,
1293        _ctx: &mut Self::CTX,
1294    ) -> Result<()> {
1295        let _ = resp.remove_header("Server");
1296        let _ = resp.insert_header("Server", DEFAULT_SERVER_NAME);
1297
1298        // S3 guarantees objects always have a Content-Type.  When the backend
1299        // (Garage) omits it (e.g. object stored without an explicit type), we
1300        // fall back to the S3 default so clients don't see a missing header.
1301        if resp.headers.get("content-type").is_none()
1302            && (resp.status == StatusCode::OK || resp.status == StatusCode::PARTIAL_CONTENT)
1303        {
1304            let _ = resp.insert_header("Content-Type", "application/octet-stream");
1305        }
1306
1307        #[cfg(feature = "metrics")]
1308        {
1309            let status = resp.status.as_str();
1310            let method = session.req_header().method.as_str();
1311            let bucket = session
1312                .req_header()
1313                .uri
1314                .path()
1315                .split('/')
1316                .nth(1)
1317                .unwrap_or("-");
1318            utils::metrics::REQUESTS_TOTAL
1319                .with_label_values(&[method, bucket, status])
1320                .inc();
1321            if resp.status.is_client_error() || resp.status.is_server_error() {
1322                utils::metrics::REQUEST_ERRORS_TOTAL
1323                    .with_label_values(&[method, bucket, status])
1324                    .inc();
1325            }
1326            if let Some(cl) = resp
1327                .headers
1328                .get("content-length")
1329                .and_then(|v| v.to_str().ok())
1330                .and_then(|v| v.parse::<i64>().ok())
1331            {
1332                utils::metrics::TRANSFER_BYTES_TOTAL
1333                    .with_label_values(&["tx", bucket])
1334                    .inc_by(cl as u64);
1335                utils::metrics::RESPONSE_SIZE_BYTES
1336                    .with_label_values(&[method, bucket])
1337                    .observe(cl as f64);
1338            }
1339        }
1340
1341        Ok(())
1342    }
1343
1344    async fn request_body_filter(
1345        &self,
1346        _session: &mut Session,
1347        body: &mut Option<bytes::Bytes>,
1348        end_of_stream: bool,
1349        ctx: &mut Self::CTX,
1350    ) -> Result<()> {
1351        // 0. Track inbound bytes regardless of streaming state
1352        #[cfg(feature = "metrics")]
1353        if let Some(payload) = body.as_ref()
1354            && !payload.is_empty()
1355        {
1356            let bucket = _session
1357                .req_header()
1358                .uri
1359                .path()
1360                .split('/')
1361                .nth(1)
1362                .unwrap_or("-");
1363            utils::metrics::TRANSFER_BYTES_TOTAL
1364                .with_label_values(&["rx", bucket])
1365                .inc_by(payload.len() as u64);
1366        }
1367
1368        // 1. Only active when we stashed a StreamingState in the request filter
1369        let Some(state) = ctx.stream_state.as_mut() else {
1370            return Ok(());
1371        };
1372
1373        // 1. Flush frames are empty and *not* EOS - just ignore them
1374        let Some(payload) = body.take() else {
1375            return Ok(());
1376        };
1377        if payload.is_empty() && !end_of_stream {
1378            return Ok(());
1379        };
1380
1381        // 2. Build the outgoing buffer.
1382        //    The incoming body already has the client's aws-chunked framing:
1383        //      <hex-size>;chunk-signature=<sig>\r\n<payload>\r\n
1384        //    We must strip that framing, extract the raw payload bytes, and
1385        //    then re-sign/re-frame for the Garage backend.
1386        let mut out = BytesMut::new();
1387        state.decode_buf.extend_from_slice(&payload);
1388
1389        while let Some((header_len, payload_len)) =
1390            signer::parse_aws_chunk_header(&state.decode_buf)
1391        {
1392            // total bytes needed: header + payload + trailing \r\n
1393            let total = header_len + payload_len + 2;
1394            if state.decode_buf.len() < total {
1395                // wait for more data
1396                break;
1397            }
1398            let raw_payload = state.decode_buf[header_len..header_len + payload_len].to_vec();
1399            use bytes::Buf;
1400            state.decode_buf.advance(total);
1401
1402            if payload_len == 0 {
1403                // This is the client's terminal empty chunk — skip it.
1404                // We will emit our own terminal chunk below when end_of_stream.
1405                break;
1406            }
1407            out.extend_from_slice(&state.sign_chunk(&raw_payload).map_err(|e| {
1408                error!("Failed to sign chunk: {e}");
1409                pingora::Error::new_str("Failed to sign chunk")
1410            })?);
1411        }
1412
1413        if end_of_stream {
1414            out.extend_from_slice(&state.final_chunk().map_err(|e| {
1415                error!("Failed to sign trailer: {e}");
1416                pingora::Error::new_str("Failed to sign trailer")
1417            })?);
1418            ctx.stream_state = None; // upload finished
1419        }
1420
1421        // 3. Hand the encoded bytes to Pingora
1422        *body = Some(out.freeze());
1423        Ok(())
1424    }
1425}
1426
1427/// Initialise the global [`tracing`] subscriber.
1428///
1429/// Configures a human-readable formatter with RFC 3339 timestamps.  The log
1430/// level is controlled by the `RUST_LOG` environment variable (e.g.
1431/// `RUST_LOG=object_storage_proxy=debug`).
1432///
1433/// This is called automatically by [`run_server`] and should not normally be
1434/// invoked by application code.
1435pub fn init_tracing() {
1436    tracing_subscriber::fmt()
1437        .with_timer(ChronoLocal::rfc_3339())
1438        .with_env_filter(EnvFilter::from_default_env())
1439        .init();
1440}
1441
1442/// Build and run the Pingora proxy server.
1443///
1444/// This is the Rust entry-point called from [`start_server`].  It:
1445/// 1. Initialises tracing.
1446/// 2. Parses the COS map and HMAC keystore from the Python objects in `run_args`.
1447/// 3. Creates the Pingora [`Server`], attaches HTTP and/or HTTPS listeners, and
1448///    enters the run-forever loop (blocking the calling thread).
1449///
1450/// # Panics
1451///
1452/// Panics if `run_args.cos_map` cannot be parsed, or if the TLS certificate /
1453/// key paths are missing when `https_port` is set.
1454pub fn run_server(py: Python, run_args: &ProxyServerConfig) {
1455    print_banner();
1456    init_tracing();
1457
1458    #[cfg(feature = "metrics")]
1459    {
1460        utils::metrics::init_metrics();
1461        if let Some(port) = run_args.metrics_port {
1462            // Spawn the metrics HTTP server on a background Tokio task.
1463            // `tokio::spawn` requires an active runtime; Pingora sets one up
1464            // during `my_server.bootstrap()` but we need a runtime here
1465            // before bootstrap, so we use a standalone one.
1466            std::thread::spawn(move || {
1467                tokio::runtime::Builder::new_current_thread()
1468                    .enable_all()
1469                    .build()
1470                    .expect("metrics runtime")
1471                    .block_on(utils::metrics::serve_metrics(port));
1472            });
1473        }
1474    }
1475
1476    if run_args.http_port.is_none() && run_args.https_port.is_none() {
1477        error!("At least one of http_port or https_port must be specified!");
1478        return;
1479    }
1480
1481    if let Some(http_port) = run_args.http_port {
1482        info!("starting HTTP server on port {}", http_port);
1483    }
1484
1485    if let Some(https_port) = run_args.https_port {
1486        info!("starting HTTPS server on port {}", https_port);
1487    }
1488
1489    let local_hmac_map = if Python::with_gil(|py| run_args.hmac_keystore.is_none(py)) {
1490        HashMap::new()
1491    } else {
1492        parse_hmac_list(py, &run_args.hmac_keystore).unwrap_or_default()
1493    };
1494
1495    debug!("HMAC keys: {:#?}", &local_hmac_map);
1496
1497    let cosmap = Arc::new(RwLock::new(
1498        parse_cos_map(py, &run_args.cos_map).expect("failed to parse cos_map"),
1499    ));
1500    let hmac_keystore = Arc::new(RwLock::new(local_hmac_map));
1501
1502    let mut my_server = Server::new(None).expect("failed to create pingora server");
1503    my_server.bootstrap();
1504
1505    let validator = run_args.validator.as_ref().map(|v| v.clone_ref(py));
1506    let hmac_fetcher = run_args.hmac_fetcher.as_ref().map(|v| v.clone_ref(py));
1507
1508    // Inspect the validator callable's arity once at startup.
1509    let validator_takes_request = run_args.validator.as_ref().map(|v| {
1510        Python::with_gil(|py| utils::functions::callable_accepts_request(py, v).unwrap_or(false))
1511    });
1512
1513    let auth_cache_instance = AuthCache::new(
1514        run_args
1515            .auth_cache_capacity
1516            .unwrap_or(utils::validator::AUTH_CACHE_DEFAULT_CAPACITY),
1517    );
1518
1519    let mut my_proxy = pingora::proxy::http_proxy_service(
1520        &my_server.configuration,
1521        MyProxy {
1522            cos_endpoint: "s3.eu-de.cloud-object-storage.appdomain.cloud".to_string(),
1523            cos_mapping: Arc::clone(&cosmap),
1524            hmac_keystore: Arc::clone(&hmac_keystore),
1525            secrets_cache: SecretsCache::new(),
1526            auth_cache: auth_cache_instance,
1527            validator,
1528            bucket_creds_fetcher: run_args
1529                .bucket_creds_fetcher
1530                .as_ref()
1531                .map(|v| v.clone_ref(py)),
1532            verify: run_args.verify,
1533            skip_signature_validation: run_args.skip_signature_validation,
1534            hmac_fetcher,
1535            tracker: UrlTracker::new(),
1536            max_presign_url_usage_attempts: run_args.max_presign_url_usage_attempts,
1537            server_name: "<osp⚡>".to_string(),
1538            validator_takes_request,
1539        },
1540    );
1541
1542    if run_args.threads.is_some() {
1543        my_proxy.threads = run_args.threads;
1544    }
1545
1546    debug!("Proxy service threads: {:?}", &my_proxy.threads);
1547
1548    if let Some(http_port) = run_args.http_port {
1549        info!("starting HTTP server on port {}", &http_port);
1550        let addr = format!("0.0.0.0:{}", http_port);
1551        my_proxy.add_tcp(addr.as_str());
1552    }
1553
1554    if let Some(https_port) = run_args.https_port {
1555        let cert_path =
1556            std::env::var("TLS_CERT_PATH").expect("Set TLS_CERT_PATH to the PEM certificate file");
1557        let key_path =
1558            std::env::var("TLS_KEY_PATH").expect("Set TLS_KEY_PATH to the PEM private-key file");
1559
1560        let mut tls = pingora::listeners::tls::TlsSettings::intermediate(&cert_path, &key_path)
1561            .expect("failed to build TLS settings");
1562
1563        tls.enable_h2();
1564        let https_addr = format!("0.0.0.0:{}", https_port);
1565        my_proxy.add_tls_with_settings(https_addr.as_str(), /*tcp_opts*/ None, tls);
1566    }
1567
1568    my_server.add_service(my_proxy);
1569
1570    debug!("{:?}", &my_server.configuration);
1571
1572    py.allow_threads(|| my_server.run_forever());
1573
1574    info!("server running ...");
1575}
1576
1577/// Start an HTTP + HTTPS reverse‑proxy for IBM COS.
1578///
1579/// Equivalent to running ``pingora`` with a custom handler.
1580///
1581/// Parameters
1582/// ----------
1583/// run_args:
1584///    A :py:class:`ProxyServerConfig` object containing the configuration for the server.
1585///     The configuration includes the following parameters:
1586///   - cos_map: A dictionary mapping bucket names to their respective COS configuration.
1587///     Each entry should contain the following
1588///     keys:
1589///        - host: The COS endpoint (e.g., "s3.eu-de.cloud-object-storage.appdomain.cloud")
1590///        - port: The port number (e.g., 443)
1591///        - api_key/apikey: The API key for the bucket (optional)
1592///        - ttl/time-to-live: The time-to-live for the API key in seconds (optional)
1593///   - bucket_creds_fetcher: Optional Python async callable that fetches the API key for a bucket.
1594///     The callable should accept a single argument, the bucket name.
1595///     It should return a string containing the API key.
1596///   - http_port: The HTTP port to listen on.
1597///   - https_port: The HTTPS port to listen on.
1598///   - validator: Optional Python async callable that validates the request.
1599///     The callable should accept two arguments, the access_key and the bucket name.
1600///     It should return a boolean indicating whether the request is valid.
1601///   - threads: Optional number of threads to use for the server.
1602///     If not specified, the server will use a single thread.
1603#[pyfunction]
1604pub fn start_server(py: Python, run_args: &ProxyServerConfig) -> PyResult<()> {
1605    rustls::crypto::ring::default_provider()
1606        .install_default()
1607        .expect("Failed to install rustls crypto provider");
1608
1609    dotenv().ok();
1610
1611    run_server(py, run_args);
1612
1613    Ok(())
1614}
1615
1616/// Enable the global request counter (disabled by default).
1617///
1618/// Once enabled every request proxied increments an atomic counter that can be
1619/// read with [`get_request_count`].  Useful for testing and load-measurement.
1620#[pyfunction]
1621fn enable_request_counting() {
1622    REQ_COUNTER_ENABLED.store(true, Ordering::Relaxed);
1623}
1624
1625/// Disable the global request counter.
1626#[pyfunction]
1627fn disable_request_counting() {
1628    REQ_COUNTER_ENABLED.store(false, Ordering::Relaxed);
1629}
1630
1631/// Return the total number of proxied requests since counting was enabled.
1632#[pyfunction]
1633fn get_request_count() -> PyResult<usize> {
1634    Ok(REQ_COUNTER.load(Ordering::Relaxed))
1635}
1636
1637#[pymodule]
1638fn object_storage_proxy(m: &Bound<'_, PyModule>) -> PyResult<()> {
1639    m.add_function(wrap_pyfunction!(start_server, m)?)?;
1640    m.add_class::<ProxyServerConfig>()?;
1641    m.add_class::<CosMapItem>()?;
1642    m.add_function(wrap_pyfunction!(enable_request_counting, m)?)?;
1643    m.add_function(wrap_pyfunction!(disable_request_counting, m)?)?;
1644    m.add_function(wrap_pyfunction!(get_request_count, m)?)?;
1645    Ok(())
1646}
1647
#[cfg(test)]
mod tests {
    use super::*;

    // ── UrlTracker ────────────────────────────────────────────────────────────

    #[test]
    fn url_tracker_new_is_empty() {
        let tracker = UrlTracker::new();
        assert!(tracker.get_all().is_empty());
    }

    #[test]
    fn url_tracker_default_equals_new() {
        // Compare the actual contents, not just the lengths: `Default` must yield
        // exactly the same (empty) state as `new()`.
        let from_new = UrlTracker::new();
        let from_default = UrlTracker::default();
        assert_eq!(from_default.get_all(), from_new.get_all());
        assert!(from_default.get_all().is_empty());
    }

    #[test]
    fn url_tracker_track_increments_count() {
        let tracker = UrlTracker::new();
        assert_eq!(tracker.get("http://example.com/key"), None);
        tracker.track("http://example.com/key");
        assert_eq!(tracker.get("http://example.com/key"), Some(1));
        tracker.track("http://example.com/key");
        assert_eq!(tracker.get("http://example.com/key"), Some(2));
    }

    #[test]
    fn url_tracker_get_returns_none_for_unknown_url() {
        let tracker = UrlTracker::new();
        assert_eq!(tracker.get("http://example.com/missing"), None);
    }

    #[test]
    fn url_tracker_counts_urls_independently() {
        // Tracking one URL must not affect the count of another.
        let tracker = UrlTracker::new();
        tracker.track("http://example.com/a");
        tracker.track("http://example.com/a");
        tracker.track("http://example.com/b");
        assert_eq!(tracker.get("http://example.com/a"), Some(2));
        assert_eq!(tracker.get("http://example.com/b"), Some(1));
        assert_eq!(tracker.get("http://example.com/c"), None);
    }

    #[test]
    fn url_tracker_get_all_returns_all_tracked_urls() {
        let tracker = UrlTracker::new();
        tracker.track("http://example.com/a");
        tracker.track("http://example.com/b");
        tracker.track("http://example.com/a");
        let mut all = tracker.get_all();
        // Sort by URL for a deterministic order; compare by reference so no
        // key is cloned on every comparison.
        all.sort_by(|left, right| left.0.cmp(&right.0));
        assert_eq!(all.len(), 2);
        assert_eq!(all[0], ("http://example.com/a".to_string(), 2));
        assert_eq!(all[1], ("http://example.com/b".to_string(), 1));
    }
}