diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index c1084961..6510aaf2 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -72,14 +72,16 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: dtolnay/rust-toolchain@1.66 + - uses: dtolnay/rust-toolchain@1.64 - name: Install protoc uses: taiki-e/install-action@v2 with: tool: protoc@3.20.3 - run: cargo update -p tokio --precise 1.38.1 - run: cargo update -p tokio-util --precise 0.7.11 - - run: cargo test -p tower-http --all-features + - run: cargo update -p flate2 --precise 1.0.35 + - run: cargo update -p once_cell --precise 1.20.3 + - run: cargo check -p tower-http --all-features style: needs: check diff --git a/examples/tonic-key-value-store/Cargo.toml b/examples/tonic-key-value-store/Cargo.toml index 4e7897a4..04de3c34 100644 --- a/examples/tonic-key-value-store/Cargo.toml +++ b/examples/tonic-key-value-store/Cargo.toml @@ -11,7 +11,7 @@ bytes = "1" hyper = { version = "0.14.4", features = ["full"] } prost = "0.11" tokio = { version = "1.2.0", features = ["full"] } -futures = "0.3" +futures-util = { version = "0.3", default-features = false } tokio-stream = { version = "0.1", features = ["sync", "net"] } tonic = "0.9" tower = { version = "0.5", features = ["full"] } diff --git a/examples/tonic-key-value-store/src/main.rs b/examples/tonic-key-value-store/src/main.rs index e615e0ed..a765400e 100644 --- a/examples/tonic-key-value-store/src/main.rs +++ b/examples/tonic-key-value-store/src/main.rs @@ -5,7 +5,7 @@ fn main() { /* use bytes::Bytes; use clap::Parser; -use futures::StreamExt; +use futures_util::StreamExt; use hyper::{ body::HttpBody, header::{self, HeaderValue}, diff --git a/test-files/empty.txt b/test-files/empty.txt new file mode 100644 index 00000000..e69de29b diff --git a/tower-http/CHANGELOG.md b/tower-http/CHANGELOG.md index 0c293002..0958beac 100644 --- a/tower-http/CHANGELOG.md +++ b/tower-http/CHANGELOG.md @@ -5,6 +5,37 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# 0.6.4 + +## Added + +- decompression: Support HTTP responses containing multiple ZSTD frames ([#548]) +- The `ServiceExt` trait for chaining layers onto an arbitrary http service just + like `ServiceBuilderExt` allows for `ServiceBuilder` ([#563]) + +## Fixed + +- Remove unnecessary trait bounds on `S::Error` for `Service` impls of + `RequestBodyTimeout` and `ResponseBodyTimeout` ([#533]) +- compression: Respect `is_end_stream` ([#535]) +- Fix a rare panic in `fs::ServeDir` ([#553]) +- Fix invalid `content-lenght` of 1 in response to range requests to empty + files ([#556]) +- In `AsyncRequireAuthorization`, use the original inner service after it is + ready, instead of using a clone ([#561]) + +[#533]: https://github.com/tower-rs/tower-http/pull/533 +[#535]: https://github.com/tower-rs/tower-http/pull/535 +[#548]: https://github.com/tower-rs/tower-http/pull/548 +[#553]: https://github.com/tower-rs/tower-http/pull/556 +[#556]: https://github.com/tower-rs/tower-http/pull/556 +[#561]: https://github.com/tower-rs/tower-http/pull/561 +[#563]: https://github.com/tower-rs/tower-http/pull/563 + +# 0.6.3 + +*This release was yanked because its definition of `ServiceExt` was quite unhelpful, in a way that's very unlikely that anybody would start depending on within the small timeframe before this was yanked, but that was technically breaking to change.* + # 0.6.2 ## Changed: diff --git a/tower-http/Cargo.toml b/tower-http/Cargo.toml index 69d28805..5026f7b1 100644 --- a/tower-http/Cargo.toml +++ b/tower-http/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "tower-http" description = "Tower middleware and utilities for HTTP clients and servers" -version = "0.6.2" +version = "0.6.4" authors = ["Tower Maintainers "] edition = "2018" license = "MIT" @@ -10,7 +10,7 @@ repository = "https://github.com/tower-rs/tower-http" homepage = "https://github.com/tower-rs/tower-http" categories = ["asynchronous", "network-programming", "web-programming"] keywords = ["io", "async", "futures", "service", "http"] -rust-version = "1.66" +rust-version = "1.64" [dependencies] bitflags = "2.0.2" diff --git a/tower-http/src/auth/async_require_authorization.rs b/tower-http/src/auth/async_require_authorization.rs index f086add2..00ab4b19 100644 --- a/tower-http/src/auth/async_require_authorization.rs +++ b/tower-http/src/auth/async_require_authorization.rs @@ -121,6 +121,7 @@ use http::{Request, Response}; use pin_project_lite::pin_project; use std::{ future::Future, + mem, pin::Pin, task::{ready, Context, Poll}, }; @@ -202,8 +203,10 @@ where } fn call(&mut self, req: Request) -> Self::Future { - let inner = self.inner.clone(); + let mut inner = self.inner.clone(); let authorize = self.auth.authorize(req); + // mem::swap due to https://docs.rs/tower/latest/tower/trait.Service.html#be-careful-when-cloning-inner-services + mem::swap(&mut self.inner, &mut inner); ResponseFuture { state: State::Authorize { authorize }, @@ -324,19 +327,15 @@ mod tests { type ResponseBody = Body; type Future = BoxFuture<'static, Result, Response>>; - fn authorize(&mut self, mut request: Request) -> Self::Future { + fn authorize(&mut self, request: Request) -> Self::Future { Box::pin(async move { let authorized = request .headers() .get(header::AUTHORIZATION) - .and_then(|it| it.to_str().ok()) - .and_then(|it| it.strip_prefix("Bearer ")) - .map(|it| it == "69420") - .unwrap_or(false); + .and_then(|auth| auth.to_str().ok()?.strip_prefix("Bearer ")) + == Some("69420"); if authorized { - let user_id = UserId("6969".to_owned()); - request.extensions_mut().insert(user_id); Ok(request) } else { Err(Response::builder() @@ -348,9 +347,6 @@ mod tests { } } - #[derive(Clone, Debug)] - struct UserId(String); - #[tokio::test] async fn require_async_auth_works() { let mut service = ServiceBuilder::new() diff --git a/tower-http/src/builder.rs b/tower-http/src/builder.rs index 58b789f2..85803855 100644 --- a/tower-http/src/builder.rs +++ b/tower-http/src/builder.rs @@ -5,6 +5,11 @@ use http::header::HeaderName; #[allow(unused_imports)] use tower_layer::Stack; +mod sealed { + #[allow(unreachable_pub, unused)] + pub trait Sealed {} +} + /// Extension trait that adds methods to [`tower::ServiceBuilder`] for adding middleware from /// tower-http. /// @@ -39,7 +44,7 @@ use tower_layer::Stack; /// ``` #[cfg(feature = "util")] // ^ work around rustdoc not inferring doc(cfg)s for cfg's from surrounding scopes -pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { +pub trait ServiceBuilderExt: sealed::Sealed + Sized { /// Propagate a header from the request to the response. /// /// See [`tower_http::propagate_header`] for more details. @@ -302,10 +307,7 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { where M: crate::request_id::MakeRequestId, { - self.set_request_id( - HeaderName::from_static(crate::request_id::X_REQUEST_ID), - make_request_id, - ) + self.set_request_id(crate::request_id::X_REQUEST_ID, make_request_id) } /// Propgate request ids from requests to responses. @@ -328,7 +330,7 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { fn propagate_x_request_id( self, ) -> ServiceBuilder> { - self.propagate_request_id(HeaderName::from_static(crate::request_id::X_REQUEST_ID)) + self.propagate_request_id(crate::request_id::X_REQUEST_ID) } /// Catch panics and convert them into `500 Internal Server` responses. @@ -366,7 +368,7 @@ pub trait ServiceBuilderExt: crate::sealed::Sealed + Sized { ) -> ServiceBuilder>; } -impl crate::sealed::Sealed for ServiceBuilder {} +impl sealed::Sealed for ServiceBuilder {} impl ServiceBuilderExt for ServiceBuilder { #[cfg(feature = "propagate-header")] diff --git a/tower-http/src/classify/grpc_errors_as_failures.rs b/tower-http/src/classify/grpc_errors_as_failures.rs index b88606b5..3fc96c33 100644 --- a/tower-http/src/classify/grpc_errors_as_failures.rs +++ b/tower-http/src/classify/grpc_errors_as_failures.rs @@ -126,7 +126,6 @@ impl GrpcCodeBitmask { /// Responses are considered successful if /// /// - `grpc-status` header value contains a success value. -/// default). /// - `grpc-status` header is missing. /// - `grpc-status` header value isn't a valid `String`. /// - `grpc-status` header value can't parsed into an `i32`. diff --git a/tower-http/src/classify/mod.rs b/tower-http/src/classify/mod.rs index 72c4db92..a3147843 100644 --- a/tower-http/src/classify/mod.rs +++ b/tower-http/src/classify/mod.rs @@ -362,7 +362,7 @@ pub enum ServerErrorsFailureClass { } impl fmt::Display for ServerErrorsFailureClass { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::StatusCode(code) => write!(f, "Status code: {}", code), Self::Error(error) => write!(f, "Error: {}", error), @@ -373,11 +373,15 @@ impl fmt::Display for ServerErrorsFailureClass { // Just verify that we can actually use this response classifier to determine retries as well #[cfg(test)] mod usable_for_retries { - #[allow(unused_imports)] - use super::*; + #![allow(dead_code)] + + use std::fmt; + use http::{Request, Response}; use tower::retry::Policy; + use super::{ClassifiedResponse, ClassifyResponse}; + trait IsRetryable { fn is_retryable(&self) -> bool; } diff --git a/tower-http/src/classify/status_in_range_is_error.rs b/tower-http/src/classify/status_in_range_is_error.rs index 934d08c5..8ff830b9 100644 --- a/tower-http/src/classify/status_in_range_is_error.rs +++ b/tower-http/src/classify/status_in_range_is_error.rs @@ -112,7 +112,7 @@ pub enum StatusInRangeFailureClass { } impl fmt::Display for StatusInRangeFailureClass { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::StatusCode(code) => write!(f, "Status code: {}", code), Self::Error(error) => write!(f, "Error: {}", error), @@ -131,23 +131,23 @@ mod tests { let classifier = StatusInRangeAsFailures::new(400..=599); assert!(matches!( - dbg!(classifier + classifier .clone() - .classify_response(&response_with_status(200))), + .classify_response(&response_with_status(200)), ClassifiedResponse::Ready(Ok(())), )); assert!(matches!( - dbg!(classifier + classifier .clone() - .classify_response(&response_with_status(400))), + .classify_response(&response_with_status(400)), ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode( StatusCode::BAD_REQUEST ))), )); assert!(matches!( - dbg!(classifier.classify_response(&response_with_status(500))), + classifier.classify_response(&response_with_status(500)), ClassifiedResponse::Ready(Err(StatusInRangeFailureClass::StatusCode( StatusCode::INTERNAL_SERVER_ERROR ))), diff --git a/tower-http/src/compression/body.rs b/tower-http/src/compression/body.rs index 989685c1..259e4a27 100644 --- a/tower-http/src/compression/body.rs +++ b/tower-http/src/compression/body.rs @@ -277,6 +277,14 @@ where http_body::SizeHint::new() } } + + fn is_end_stream(&self) -> bool { + if let BodyInner::Identity { inner } = &self.inner { + inner.is_end_stream() + } else { + false + } + } } #[cfg(feature = "compression-gzip")] diff --git a/tower-http/src/decompression/body.rs b/tower-http/src/decompression/body.rs index 9378e5ca..a2970d65 100644 --- a/tower-http/src/decompression/body.rs +++ b/tower-http/src/decompression/body.rs @@ -397,7 +397,9 @@ where type Output = ZstdDecoder; fn apply(input: Self::Input, _quality: CompressionLevel) -> Self::Output { - ZstdDecoder::new(input) + let mut decoder = ZstdDecoder::new(input); + decoder.multiple_members(true); + decoder } fn get_pin_mut(pinned: Pin<&mut Self::Output>) -> Pin<&mut Self::Input> { diff --git a/tower-http/src/decompression/mod.rs b/tower-http/src/decompression/mod.rs index 708df439..50d4d5fa 100644 --- a/tower-http/src/decompression/mod.rs +++ b/tower-http/src/decompression/mod.rs @@ -168,6 +168,24 @@ mod tests { assert_eq!(decompressed_data, "Hello, World!"); } + #[tokio::test] + async fn decompress_multi_zstd() { + let mut client = Decompression::new(service_fn(handle_multi_zstd)); + + let req = Request::builder() + .header("accept-encoding", "zstd") + .body(Body::empty()) + .unwrap(); + let res = client.ready().await.unwrap().call(req).await.unwrap(); + + // read the body, it will be decompressed automatically + let body = res.into_body(); + let decompressed_data = + String::from_utf8(body.collect().await.unwrap().to_bytes().to_vec()).unwrap(); + + assert_eq!(decompressed_data, "Hello, World!"); + } + async fn handle_multi_gz(_req: Request) -> Result, Infallible> { let mut buf = Vec::new(); let mut enc1 = GzEncoder::new(&mut buf, Default::default()); @@ -184,6 +202,22 @@ mod tests { Ok(res) } + async fn handle_multi_zstd(_req: Request) -> Result, Infallible> { + let mut buf = Vec::new(); + let mut enc1 = zstd::Encoder::new(&mut buf, Default::default()).unwrap(); + enc1.write_all(b"Hello, ").unwrap(); + enc1.finish().unwrap(); + + let mut enc2 = zstd::Encoder::new(&mut buf, Default::default()).unwrap(); + enc2.write_all(b"World!").unwrap(); + enc2.finish().unwrap(); + + let mut res = Response::new(Body::from(buf)); + res.headers_mut() + .insert("content-encoding", "zstd".parse().unwrap()); + Ok(res) + } + #[allow(dead_code)] async fn is_compatible_with_hyper() { let client = diff --git a/tower-http/src/lib.rs b/tower-http/src/lib.rs index 8d254e1d..372bef8c 100644 --- a/tower-http/src/lib.rs +++ b/tower-http/src/lib.rs @@ -189,7 +189,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -335,10 +334,12 @@ pub mod services; #[cfg(feature = "util")] mod builder; +#[cfg(feature = "util")] +mod service_ext; #[cfg(feature = "util")] #[doc(inline)] -pub use self::builder::ServiceBuilderExt; +pub use self::{builder::ServiceBuilderExt, service_ext::ServiceExt}; #[cfg(feature = "validate-request")] pub mod validate_request; @@ -370,8 +371,3 @@ pub enum LatencyUnit { /// Alias for a type-erased error type. pub type BoxError = Box; - -mod sealed { - #[allow(unreachable_pub, unused)] - pub trait Sealed {} -} diff --git a/tower-http/src/request_id.rs b/tower-http/src/request_id.rs index 1db2d02a..3c8c43fa 100644 --- a/tower-http/src/request_id.rs +++ b/tower-http/src/request_id.rs @@ -181,7 +181,7 @@ use tower_layer::Layer; use tower_service::Service; use uuid::Uuid; -pub(crate) const X_REQUEST_ID: &str = "x-request-id"; +pub(crate) const X_REQUEST_ID: HeaderName = HeaderName::from_static("x-request-id"); /// Trait for producing [`RequestId`]s. /// @@ -246,7 +246,7 @@ impl SetRequestIdLayer { where M: MakeRequestId, { - SetRequestIdLayer::new(HeaderName::from_static(X_REQUEST_ID), make_request_id) + SetRequestIdLayer::new(X_REQUEST_ID, make_request_id) } } @@ -299,11 +299,7 @@ impl SetRequestId { where M: MakeRequestId, { - Self::new( - inner, - HeaderName::from_static(X_REQUEST_ID), - make_request_id, - ) + Self::new(inner, X_REQUEST_ID, make_request_id) } define_inner_service_accessors!(); @@ -365,7 +361,7 @@ impl PropagateRequestIdLayer { /// Create a new `PropagateRequestIdLayer` that uses `x-request-id` as the header name. pub fn x_request_id() -> Self { - Self::new(HeaderName::from_static(X_REQUEST_ID)) + Self::new(X_REQUEST_ID) } } @@ -397,7 +393,7 @@ impl PropagateRequestId { /// Create a new `PropagateRequestId` that uses `x-request-id` as the header name. pub fn x_request_id(inner: S) -> Self { - Self::new(inner, HeaderName::from_static(X_REQUEST_ID)) + Self::new(inner, X_REQUEST_ID) } define_inner_service_accessors!(); diff --git a/tower-http/src/service_ext.rs b/tower-http/src/service_ext.rs new file mode 100644 index 00000000..3221afab --- /dev/null +++ b/tower-http/src/service_ext.rs @@ -0,0 +1,429 @@ +#[allow(unused_imports)] +use http::header::HeaderName; + +/// Extension trait that adds methods to any [`Service`] for adding middleware from +/// tower-http. +/// +/// [`Service`]: tower::Service +#[cfg(feature = "util")] +// ^ work around rustdoc not inferring doc(cfg)s for cfg's from surrounding scopes +pub trait ServiceExt { + /// Propagate a header from the request to the response. + /// + /// See [`tower_http::propagate_header`] for more details. + /// + /// [`tower_http::propagate_header`]: crate::propagate_header + #[cfg(feature = "propagate-header")] + fn propagate_header(self, header: HeaderName) -> crate::propagate_header::PropagateHeader + where + Self: Sized, + { + crate::propagate_header::PropagateHeader::new(self, header) + } + + /// Add some shareable value to [request extensions]. + /// + /// See [`tower_http::add_extension`] for more details. + /// + /// [`tower_http::add_extension`]: crate::add_extension + /// [request extensions]: https://docs.rs/http/latest/http/struct.Extensions.html + #[cfg(feature = "add-extension")] + fn add_extension(self, value: T) -> crate::add_extension::AddExtension + where + Self: Sized, + { + crate::add_extension::AddExtension::new(self, value) + } + + /// Apply a transformation to the request body. + /// + /// See [`tower_http::map_request_body`] for more details. + /// + /// [`tower_http::map_request_body`]: crate::map_request_body + #[cfg(feature = "map-request-body")] + fn map_request_body(self, f: F) -> crate::map_request_body::MapRequestBody + where + Self: Sized, + { + crate::map_request_body::MapRequestBody::new(self, f) + } + + /// Apply a transformation to the response body. + /// + /// See [`tower_http::map_response_body`] for more details. + /// + /// [`tower_http::map_response_body`]: crate::map_response_body + #[cfg(feature = "map-response-body")] + fn map_response_body(self, f: F) -> crate::map_response_body::MapResponseBody + where + Self: Sized, + { + crate::map_response_body::MapResponseBody::new(self, f) + } + + /// Compresses response bodies. + /// + /// See [`tower_http::compression`] for more details. + /// + /// [`tower_http::compression`]: crate::compression + #[cfg(any( + feature = "compression-br", + feature = "compression-deflate", + feature = "compression-gzip", + feature = "compression-zstd", + ))] + fn compression(self) -> crate::compression::Compression + where + Self: Sized, + { + crate::compression::Compression::new(self) + } + + /// Decompress response bodies. + /// + /// See [`tower_http::decompression`] for more details. + /// + /// [`tower_http::decompression`]: crate::decompression + #[cfg(any( + feature = "decompression-br", + feature = "decompression-deflate", + feature = "decompression-gzip", + feature = "decompression-zstd", + ))] + fn decompression(self) -> crate::decompression::Decompression + where + Self: Sized, + { + crate::decompression::Decompression::new(self) + } + + /// High level tracing that classifies responses using HTTP status codes. + /// + /// This method does not support customizing the output, to do that use [`TraceLayer`] + /// instead. + /// + /// See [`tower_http::trace`] for more details. + /// + /// [`tower_http::trace`]: crate::trace + /// [`TraceLayer`]: crate::trace::TraceLayer + #[cfg(feature = "trace")] + fn trace_for_http(self) -> crate::trace::Trace + where + Self: Sized, + { + crate::trace::Trace::new_for_http(self) + } + + /// High level tracing that classifies responses using gRPC headers. + /// + /// This method does not support customizing the output, to do that use [`TraceLayer`] + /// instead. + /// + /// See [`tower_http::trace`] for more details. + /// + /// [`tower_http::trace`]: crate::trace + /// [`TraceLayer`]: crate::trace::TraceLayer + #[cfg(feature = "trace")] + fn trace_for_grpc(self) -> crate::trace::Trace + where + Self: Sized, + { + crate::trace::Trace::new_for_grpc(self) + } + + /// Follow redirect resposes using the [`Standard`] policy. + /// + /// See [`tower_http::follow_redirect`] for more details. + /// + /// [`tower_http::follow_redirect`]: crate::follow_redirect + /// [`Standard`]: crate::follow_redirect::policy::Standard + #[cfg(feature = "follow-redirect")] + fn follow_redirects( + self, + ) -> crate::follow_redirect::FollowRedirect + where + Self: Sized, + { + crate::follow_redirect::FollowRedirect::new(self) + } + + /// Mark headers as [sensitive] on both requests and responses. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_headers( + self, + headers: impl IntoIterator, + ) -> crate::sensitive_headers::SetSensitiveHeaders + where + Self: Sized, + { + use tower_layer::Layer as _; + crate::sensitive_headers::SetSensitiveHeadersLayer::new(headers).layer(self) + } + + /// Mark headers as [sensitive] on requests. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_request_headers( + self, + headers: impl IntoIterator, + ) -> crate::sensitive_headers::SetSensitiveRequestHeaders + where + Self: Sized, + { + crate::sensitive_headers::SetSensitiveRequestHeaders::new(self, headers) + } + + /// Mark headers as [sensitive] on responses. + /// + /// See [`tower_http::sensitive_headers`] for more details. + /// + /// [sensitive]: https://docs.rs/http/latest/http/header/struct.HeaderValue.html#method.set_sensitive + /// [`tower_http::sensitive_headers`]: crate::sensitive_headers + #[cfg(feature = "sensitive-headers")] + fn sensitive_response_headers( + self, + headers: impl IntoIterator, + ) -> crate::sensitive_headers::SetSensitiveResponseHeaders + where + Self: Sized, + { + crate::sensitive_headers::SetSensitiveResponseHeaders::new(self, headers) + } + + /// Insert a header into the request. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn override_request_header( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetRequestHeader + where + Self: Sized, + { + crate::set_header::SetRequestHeader::overriding(self, header_name, make) + } + + /// Append a header into the request. + /// + /// If previous values exist, the header will have multiple values. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn append_request_header( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetRequestHeader + where + Self: Sized, + { + crate::set_header::SetRequestHeader::appending(self, header_name, make) + } + + /// Insert a header into the request, if the header is not already present. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn insert_request_header_if_not_present( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetRequestHeader + where + Self: Sized, + { + crate::set_header::SetRequestHeader::if_not_present(self, header_name, make) + } + + /// Insert a header into the response. + /// + /// If a previous value exists for the same header, it is removed and replaced with the new + /// header value. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn override_response_header( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetResponseHeader + where + Self: Sized, + { + crate::set_header::SetResponseHeader::overriding(self, header_name, make) + } + + /// Append a header into the response. + /// + /// If previous values exist, the header will have multiple values. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn append_response_header( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetResponseHeader + where + Self: Sized, + { + crate::set_header::SetResponseHeader::appending(self, header_name, make) + } + + /// Insert a header into the response, if the header is not already present. + /// + /// See [`tower_http::set_header`] for more details. + /// + /// [`tower_http::set_header`]: crate::set_header + #[cfg(feature = "set-header")] + fn insert_response_header_if_not_present( + self, + header_name: HeaderName, + make: M, + ) -> crate::set_header::SetResponseHeader + where + Self: Sized, + { + crate::set_header::SetResponseHeader::if_not_present(self, header_name, make) + } + + /// Add request id header and extension. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn set_request_id( + self, + header_name: HeaderName, + make_request_id: M, + ) -> crate::request_id::SetRequestId + where + Self: Sized, + M: crate::request_id::MakeRequestId, + { + crate::request_id::SetRequestId::new(self, header_name, make_request_id) + } + + /// Add request id header and extension, using `x-request-id` as the header name. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn set_x_request_id(self, make_request_id: M) -> crate::request_id::SetRequestId + where + Self: Sized, + M: crate::request_id::MakeRequestId, + { + self.set_request_id(crate::request_id::X_REQUEST_ID, make_request_id) + } + + /// Propgate request ids from requests to responses. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn propagate_request_id( + self, + header_name: HeaderName, + ) -> crate::request_id::PropagateRequestId + where + Self: Sized, + { + crate::request_id::PropagateRequestId::new(self, header_name) + } + + /// Propgate request ids from requests to responses, using `x-request-id` as the header name. + /// + /// See [`tower_http::request_id`] for more details. + /// + /// [`tower_http::request_id`]: crate::request_id + #[cfg(feature = "request-id")] + fn propagate_x_request_id(self) -> crate::request_id::PropagateRequestId + where + Self: Sized, + { + self.propagate_request_id(crate::request_id::X_REQUEST_ID) + } + + /// Catch panics and convert them into `500 Internal Server` responses. + /// + /// See [`tower_http::catch_panic`] for more details. + /// + /// [`tower_http::catch_panic`]: crate::catch_panic + #[cfg(feature = "catch-panic")] + fn catch_panic( + self, + ) -> crate::catch_panic::CatchPanic + where + Self: Sized, + { + crate::catch_panic::CatchPanic::new(self) + } + + /// Intercept requests with over-sized payloads and convert them into + /// `413 Payload Too Large` responses. + /// + /// See [`tower_http::limit`] for more details. + /// + /// [`tower_http::limit`]: crate::limit + #[cfg(feature = "limit")] + fn request_body_limit(self, limit: usize) -> crate::limit::RequestBodyLimit + where + Self: Sized, + { + crate::limit::RequestBodyLimit::new(self, limit) + } + + /// Remove trailing slashes from paths. + /// + /// See [`tower_http::normalize_path`] for more details. + /// + /// [`tower_http::normalize_path`]: crate::normalize_path + #[cfg(feature = "normalize-path")] + fn trim_trailing_slash(self) -> crate::normalize_path::NormalizePath + where + Self: Sized, + { + crate::normalize_path::NormalizePath::trim_trailing_slash(self) + } +} + +impl ServiceExt for T {} + +#[cfg(all(test, feature = "fs", feature = "add-extension"))] +mod tests { + use super::ServiceExt; + use crate::services; + + #[allow(dead_code)] + fn test_type_inference() { + let _svc = services::fs::ServeDir::new(".").add_extension("&'static str"); + } +} diff --git a/tower-http/src/services/fs/serve_dir/future.rs b/tower-http/src/services/fs/serve_dir/future.rs index 8b255f72..305029be 100644 --- a/tower-http/src/services/fs/serve_dir/future.rs +++ b/tower-http/src/services/fs/serve_dir/future.rs @@ -122,6 +122,12 @@ where break Poll::Ready(Ok(response_with_status(StatusCode::NOT_MODIFIED))); } + Ok(OpenFileOutput::InvalidRedirectUri) => { + break Poll::Ready(Ok(response_with_status( + StatusCode::INTERNAL_SERVER_ERROR, + ))); + } + Err(err) => { #[cfg(unix)] // 20 = libc::ENOTDIR => "not a directory @@ -263,12 +269,18 @@ fn build_response(output: FileOpened) -> Response { empty_body() }; + let content_length = if size == 0 { + 0 + } else { + range.end() - range.start() + 1 + }; + builder .header( header::CONTENT_RANGE, format!("bytes {}-{}/{}", range.start(), range.end(), size), ) - .header(header::CONTENT_LENGTH, range.end() - range.start() + 1) + .header(header::CONTENT_LENGTH, content_length) .status(StatusCode::PARTIAL_CONTENT) .body(body) .unwrap() diff --git a/tower-http/src/services/fs/serve_dir/open_file.rs b/tower-http/src/services/fs/serve_dir/open_file.rs index f182d422..9ddedd8a 100644 --- a/tower-http/src/services/fs/serve_dir/open_file.rs +++ b/tower-http/src/services/fs/serve_dir/open_file.rs @@ -22,6 +22,7 @@ pub(super) enum OpenFileOutput { FileNotFound, PreconditionFailed, NotModified, + InvalidRedirectUri, } pub(super) struct FileOpened { @@ -215,10 +216,9 @@ async fn open_file_with_fallback( // Remove the encoding from the negotiated_encodings since the file doesn't exist negotiated_encoding .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); - continue; } (Err(err), _) => return Err(err), - }; + } }; Ok((file, encoding)) } @@ -242,10 +242,9 @@ async fn file_metadata_with_fallback( // Remove the encoding from the negotiated_encodings since the file doesn't exist negotiated_encoding .retain(|(negotiated_encoding, _)| *negotiated_encoding != encoding); - continue; } (Err(err), _) => return Err(err), - }; + } }; Ok((file, encoding)) } @@ -267,8 +266,11 @@ async fn maybe_redirect_or_append_path( path_to_file.push("index.html"); None } else { - let location = - HeaderValue::from_str(&append_slash_on_path(uri.clone()).to_string()).unwrap(); + let uri = match append_slash_on_path(uri.clone()) { + Ok(uri) => uri, + Err(err) => return Some(err), + }; + let location = HeaderValue::from_str(&uri.to_string()).unwrap(); Some(OpenFileOutput::Redirect { location }) } } @@ -289,7 +291,7 @@ async fn is_dir(path_to_file: &Path) -> bool { .map_or(false, |meta_data| meta_data.is_dir()) } -fn append_slash_on_path(uri: Uri) -> Uri { +fn append_slash_on_path(uri: Uri) -> Result { let http::uri::Parts { scheme, authority, @@ -317,7 +319,10 @@ fn append_slash_on_path(uri: Uri) -> Uri { uri_builder.path_and_query("/") }; - uri_builder.build().unwrap() + uri_builder.build().map_err(|err| { + tracing::error!(?err, "redirect uri failed to build"); + OpenFileOutput::InvalidRedirectUri + }) } #[test] diff --git a/tower-http/src/services/fs/serve_dir/tests.rs b/tower-http/src/services/fs/serve_dir/tests.rs index 1fd768c2..ea1c543e 100644 --- a/tower-http/src/services/fs/serve_dir/tests.rs +++ b/tower-http/src/services/fs/serve_dir/tests.rs @@ -506,6 +506,25 @@ async fn access_space_percent_encoded_uri_path() { assert_eq!(res.headers()["content-type"], "text/plain"); } +#[tokio::test] +async fn read_partial_empty() { + let svc = ServeDir::new("../test-files"); + + let req = Request::builder() + .uri("/empty.txt") + .header("Range", "bytes=0-") + .body(Body::empty()) + .unwrap(); + let res = svc.oneshot(req).await.unwrap(); + + assert_eq!(res.status(), StatusCode::PARTIAL_CONTENT); + assert_eq!(res.headers()["content-length"], "0"); + assert_eq!(res.headers()["content-range"], "bytes 0-0/0"); + + let body = to_bytes(res.into_body()).await.ok().unwrap(); + assert!(body.is_empty()); +} + #[tokio::test] async fn read_partial_in_bounds() { let svc = ServeDir::new(".."); diff --git a/tower-http/src/timeout/service.rs b/tower-http/src/timeout/service.rs index 8371b03f..230fe717 100644 --- a/tower-http/src/timeout/service.rs +++ b/tower-http/src/timeout/service.rs @@ -165,7 +165,6 @@ impl RequestBodyTimeout { impl Service> for RequestBodyTimeout where S: Service>>, - S::Error: Into>, { type Response = S::Response; type Error = S::Error; @@ -212,7 +211,6 @@ pub struct ResponseBodyTimeout { impl Service> for ResponseBodyTimeout where S: Service, Response = Response>, - S::Error: Into>, { type Response = Response>; type Error = S::Error; diff --git a/tower-http/src/trace/mod.rs b/tower-http/src/trace/mod.rs index 65734a42..ec5036aa 100644 --- a/tower-http/src/trace/mod.rs +++ b/tower-http/src/trace/mod.rs @@ -369,9 +369,9 @@ //! [`TraceLayer`] comes with convenience methods for using common classifiers: //! //! - [`TraceLayer::new_for_http`] classifies based on the status code. It doesn't consider -//! streaming responses. +//! streaming responses. //! - [`TraceLayer::new_for_grpc`] classifies based on the gRPC protocol and supports streaming -//! responses. +//! responses. //! //! [tracing]: https://crates.io/crates/tracing //! [`Service`]: tower_service::Service @@ -516,7 +516,7 @@ mod tests { tracing::info_span!("test-span", foo = tracing::field::Empty) }) .on_request(|_req: &Request, span: &Span| { - span.record("foo", &42); + span.record("foo", 42); ON_REQUEST_COUNT.fetch_add(1, Ordering::SeqCst); }) .on_response(|_res: &Response, _latency: Duration, _span: &Span| {