diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 000000000..08891d83f --- /dev/null +++ b/.editorconfig @@ -0,0 +1,8 @@ +root = true + +[*] +indent_style = space +indent_size = 4 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true \ No newline at end of file diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 000000000..e29eb8464 --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,14 @@ +changelog: + categories: + - title: SemVer Major + labels: + - ⚠️ semver/major + - title: SemVer Minor + labels: + - 🆕 semver/minor + - title: SemVer Patch + labels: + - 🔨 semver/patch + - title: Other Changes + labels: + - semver/none diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 000000000..3bf5a95ec --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,23 @@ +name: Main + +on: + push: + branches: [main] + schedule: + - cron: "0 8,20 * * *" + +jobs: + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_10_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_nightly_next_arguments_override: "--explicit-target-dependency-import-check error" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error" + + static-sdk: + name: Static SDK + # Workaround https://github.com/nektos/act/issues/1875 + uses: apple/swift-nio/.github/workflows/static_sdk.yml@main diff --git a/.github/workflows/pull_request.yml b/.github/workflows/pull_request.yml new file mode 100644 index 000000000..8036d7ad7 --- /dev/null +++ b/.github/workflows/pull_request.yml @@ -0,0 +1,32 @@ +name: PR + +on: + pull_request: + types: [opened, reopened, synchronize] + +jobs: + soundness: + name: Soundness + uses: swiftlang/github-workflows/.github/workflows/soundness.yml@main + with: + license_header_check_project_name: "AsyncHTTPClient" + unit-tests: + name: Unit tests + uses: apple/swift-nio/.github/workflows/unit_tests.yml@main + with: + linux_5_10_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_0_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_6_1_arguments_override: "--explicit-target-dependency-import-check error -Xswiftc -warnings-as-errors" + linux_nightly_next_arguments_override: "--explicit-target-dependency-import-check error" + linux_nightly_main_arguments_override: "--explicit-target-dependency-import-check error" + + cxx-interop: + name: Cxx interop + uses: apple/swift-nio/.github/workflows/cxx_interop.yml@main + with: + linux_5_9_enabled: false + + static-sdk: + name: Static SDK + # Workaround https://github.com/nektos/act/issues/1875 + uses: apple/swift-nio/.github/workflows/static_sdk.yml@main diff --git a/.github/workflows/pull_request_label.yml b/.github/workflows/pull_request_label.yml new file mode 100644 index 000000000..8fd47c13f --- /dev/null +++ b/.github/workflows/pull_request_label.yml @@ -0,0 +1,18 @@ +name: PR label + +on: + pull_request: + types: [labeled, unlabeled, opened, reopened, synchronize] + +jobs: + semver-label-check: + name: Semantic version label check + runs-on: ubuntu-latest + timeout-minutes: 1 + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + persist-credentials: false + - name: Check for Semantic Version label + uses: apple/swift-nio/.github/actions/pull_request_semver_label_checker@main diff --git a/.licenseignore b/.licenseignore new file mode 100644 index 000000000..edceaab62 --- /dev/null +++ b/.licenseignore @@ -0,0 +1,37 @@ +.gitignore +**/.gitignore +.licenseignore +.gitattributes +.git-blame-ignore-revs +.mailfilter +.mailmap +.spi.yml +.swift-format +.editorconfig +.github/* +*.md +*.txt +*.yml +*.yaml +*.json +Package.swift +**/Package.swift +Package@-*.swift +**/Package@-*.swift +Package.resolved +**/Package.resolved +Makefile +*.modulemap +**/*.modulemap +**/*.docc/* +*.xcprivacy +**/*.xcprivacy +*.symlink +**/*.symlink +Dockerfile +**/Dockerfile +.dockerignore +Snippets/* +dev/git.commit.template +.unacceptablelanguageignore +Tests/AsyncHTTPClientTests/Resources/*.pem diff --git a/.swift-format b/.swift-format new file mode 100644 index 000000000..7e8ae7391 --- /dev/null +++ b/.swift-format @@ -0,0 +1,68 @@ +{ + "version" : 1, + "indentation" : { + "spaces" : 4 + }, + "tabWidth" : 4, + "fileScopedDeclarationPrivacy" : { + "accessLevel" : "private" + }, + "spacesAroundRangeFormationOperators" : false, + "indentConditionalCompilationBlocks" : false, + "indentSwitchCaseLabels" : false, + "lineBreakAroundMultilineExpressionChainComponents" : false, + "lineBreakBeforeControlFlowKeywords" : false, + "lineBreakBeforeEachArgument" : true, + "lineBreakBeforeEachGenericRequirement" : true, + "lineLength" : 120, + "maximumBlankLines" : 1, + "respectsExistingLineBreaks" : true, + "prioritizeKeepingFunctionOutputTogether" : true, + "noAssignmentInExpressions" : { + "allowedFunctions" : [ + "XCTAssertNoThrow", + "XCTAssertThrowsError" + ] + }, + "rules" : { + "AllPublicDeclarationsHaveDocumentation" : false, + "AlwaysUseLiteralForEmptyCollectionInit" : false, + "AlwaysUseLowerCamelCase" : false, + "AmbiguousTrailingClosureOverload" : true, + "BeginDocumentationCommentWithOneLineSummary" : false, + "DoNotUseSemicolons" : true, + "DontRepeatTypeInStaticProperties" : true, + "FileScopedDeclarationPrivacy" : true, + "FullyIndirectEnum" : true, + "GroupNumericLiterals" : true, + "IdentifiersMustBeASCII" : true, + "NeverForceUnwrap" : false, + "NeverUseForceTry" : false, + "NeverUseImplicitlyUnwrappedOptionals" : false, + "NoAccessLevelOnExtensionDeclaration" : true, + "NoAssignmentInExpressions" : true, + "NoBlockComments" : true, + "NoCasesWithOnlyFallthrough" : true, + "NoEmptyTrailingClosureParentheses" : true, + "NoLabelsInCasePatterns" : true, + "NoLeadingUnderscores" : false, + "NoParensAroundConditions" : true, + "NoVoidReturnOnFunctionSignature" : true, + "OmitExplicitReturns" : true, + "OneCasePerLine" : true, + "OneVariableDeclarationPerLine" : true, + "OnlyOneTrailingClosureArgument" : true, + "OrderedImports" : true, + "ReplaceForEachWithForLoop" : true, + "ReturnVoidInsteadOfEmptyTuple" : true, + "UseEarlyExits" : false, + "UseExplicitNilCheckInConditions" : false, + "UseLetInEveryBoundCaseVariable" : false, + "UseShorthandTypeNames" : true, + "UseSingleLinePropertyGetter" : false, + "UseSynthesizedInitializer" : false, + "UseTripleSlashForDocumentationComments" : true, + "UseWhereClausesInForLoops" : false, + "ValidateDocumentationComments" : false + } +} diff --git a/.swiftformat b/.swiftformat deleted file mode 100644 index 7b7c486ea..000000000 --- a/.swiftformat +++ /dev/null @@ -1,24 +0,0 @@ -# file options - ---swiftversion 5.4 ---exclude .build - -# format options - ---self insert ---patternlet inline ---ranges nospace ---stripunusedargs unnamed-only ---ifdef no-indent ---extensionacl on-declarations ---disable typeSugar # https://github.com/nicklockwood/SwiftFormat/issues/636 ---disable andOperator ---disable wrapMultilineStatementBraces ---disable enumNamespaces ---disable redundantExtensionACL ---disable redundantReturn ---disable preferKeyPath ---disable sortedSwitchCases ---disable numberFormatting - -# rules diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3803bb618..dddcb3ba4 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -65,10 +65,10 @@ We require that your commit messages match our template. The easiest way to do t git config commit.template dev/git.commit.template -### Make sure Tests work on Linux -AsyncHTTPClient uses XCTest to run tests on both macOS and Linux. While the macOS version of XCTest is able to use the Objective-C runtime to discover tests at execution time, the Linux version is not. -For this reason, whenever you add new tests **you have to run a script** that generates the hooks needed to run those tests on Linux, or our CI will complain that the tests are not all present on Linux. To do this, merely execute `ruby ./scripts/generate_linux_tests.rb` at the root of the package and check the changes it made. +### Run CI checks locally + +You can run the Github Actions workflows locally using [act](https://github.com/nektos/act). For detailed steps on how to do this please see [https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally](https://github.com/swiftlang/github-workflows?tab=readme-ov-file#running-workflows-locally). ## How to contribute your work diff --git a/Examples/GetHTML/GetHTML.swift b/Examples/GetHTML/GetHTML.swift index dfefa922b..ca3bacbea 100644 --- a/Examples/GetHTML/GetHTML.swift +++ b/Examples/GetHTML/GetHTML.swift @@ -18,12 +18,12 @@ import NIOCore @main struct GetHTML { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://apple.com") let response = try await httpClient.execute(request, timeout: .seconds(30)) print("HTTP head", response) - let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB print(String(buffer: body)) } catch { print("request failed:", error) diff --git a/Examples/GetJSON/GetJSON.swift b/Examples/GetJSON/GetJSON.swift index ae58ffeaa..1af7a5144 100644 --- a/Examples/GetJSON/GetJSON.swift +++ b/Examples/GetJSON/GetJSON.swift @@ -33,12 +33,12 @@ struct Comic: Codable { @main struct GetJSON { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://xkcd.com/info.0.json") let response = try await httpClient.execute(request, timeout: .seconds(30)) print("HTTP head", response) - let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB + let body = try await response.body.collect(upTo: 1024 * 1024) // 1 MB // we use an overload defined in `NIOFoundationCompat` for `decode(_:from:)` to // efficiently decode from a `ByteBuffer` let comic = try JSONDecoder().decode(Comic.self, from: body) diff --git a/Examples/Package.swift b/Examples/Package.swift index 696092cba..9986b17b5 100644 --- a/Examples/Package.swift +++ b/Examples/Package.swift @@ -43,7 +43,8 @@ let package = Package( dependencies: [ .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), - ], path: "GetHTML" + ], + path: "GetHTML" ), .executableTarget( name: "GetJSON", @@ -51,14 +52,16 @@ let package = Package( .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), .product(name: "NIOFoundationCompat", package: "swift-nio"), - ], path: "GetJSON" + ], + path: "GetJSON" ), .executableTarget( name: "StreamingByteCounter", dependencies: [ .product(name: "AsyncHTTPClient", package: "async-http-client"), .product(name: "NIOCore", package: "swift-nio"), - ], path: "StreamingByteCounter" + ], + path: "StreamingByteCounter" ), ] ) diff --git a/Examples/StreamingByteCounter/StreamingByteCounter.swift b/Examples/StreamingByteCounter/StreamingByteCounter.swift index dc340d14b..ecfb48776 100644 --- a/Examples/StreamingByteCounter/StreamingByteCounter.swift +++ b/Examples/StreamingByteCounter/StreamingByteCounter.swift @@ -18,7 +18,7 @@ import NIOCore @main struct StreamingByteCounter { static func main() async throws { - let httpClient = HTTPClient(eventLoopGroupProvider: .createNew) + let httpClient = HTTPClient(eventLoopGroupProvider: .singleton) do { let request = HTTPClientRequest(url: "https://apple.com") let response = try await httpClient.execute(request, timeout: .seconds(30)) diff --git a/NOTICE.txt b/NOTICE.txt index 095a11740..86a969171 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -50,13 +50,13 @@ This product contains a derivation of the Tony Stone's 'process_test_files.rb'. * https://www.apache.org/licenses/LICENSE-2.0 * HOMEPAGE: * https://github.com/tonystone/build-tools/commit/6c417b7569df24597a48a9aa7b505b636e8f73a1 - * https://github.com/tonystone/build-tools/blob/master/source/xctest_tool.rb + * https://github.com/tonystone/build-tools/blob/cf3440f43bde2053430285b4ed0709c865892eb5/source/xctest_tool.rb --- This product contains a derivation of Fabian Fett's 'Base64.swift'. * LICENSE (Apache License 2.0): - * https://github.com/fabianfett/swift-base64-kit/blob/master/LICENSE + * https://github.com/swift-extras/swift-extras-base64/blob/b8af49699d59ad065b801715a5009619100245ca/LICENSE * HOMEPAGE: * https://github.com/fabianfett/swift-base64-kit diff --git a/Package.swift b/Package.swift index dae0d91c7..3294781a9 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version:5.8 +// swift-tools-version:5.10 //===----------------------------------------------------------------------===// // // This source file is part of the AsyncHTTPClient open source project @@ -15,27 +15,44 @@ import PackageDescription +let strictConcurrencyDevelopment = false + +let strictConcurrencySettings: [SwiftSetting] = { + var initialSettings: [SwiftSetting] = [] + initialSettings.append(contentsOf: [ + .enableUpcomingFeature("StrictConcurrency"), + .enableUpcomingFeature("InferSendableFromCaptures"), + ]) + + if strictConcurrencyDevelopment { + // -warnings-as-errors here is a workaround so that IDE-based development can + // get tripped up on -require-explicit-sendable. + initialSettings.append(.unsafeFlags(["-Xfrontend", "-require-explicit-sendable", "-warnings-as-errors"])) + } + + return initialSettings +}() + let package = Package( name: "async-http-client", products: [ - .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]), + .library(name: "AsyncHTTPClient", targets: ["AsyncHTTPClient"]) ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"), - .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.22.0"), - .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.19.0"), - .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.13.0"), - .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.19.0"), - .package(url: "https://github.com/apple/swift-log.git", from: "1.4.4"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"), + .package(url: "https://github.com/apple/swift-nio-ssl.git", from: "2.30.0"), + .package(url: "https://github.com/apple/swift-nio-http2.git", from: "1.36.0"), + .package(url: "https://github.com/apple/swift-nio-extras.git", from: "1.26.0"), + .package(url: "https://github.com/apple/swift-nio-transport-services.git", from: "1.24.0"), + .package(url: "https://github.com/apple/swift-log.git", from: "1.6.0"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), - .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.0.0"), - .package(url: "https://github.com/apple/swift-algorithms", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-algorithms.git", from: "1.0.0"), ], targets: [ .target( name: "CAsyncHTTPClient", cSettings: [ - .define("_GNU_SOURCE"), + .define("_GNU_SOURCE") ] ), .target( @@ -56,7 +73,8 @@ let package = Package( .product(name: "Logging", package: "swift-log"), .product(name: "Atomics", package: "swift-atomics"), .product(name: "Algorithms", package: "swift-algorithms"), - ] + ], + swiftSettings: strictConcurrencySettings ), .testTarget( name: "AsyncHTTPClientTests", @@ -80,7 +98,24 @@ let package = Package( .copy("Resources/self_signed_key.pem"), .copy("Resources/example.com.cert.pem"), .copy("Resources/example.com.private-key.pem"), - ] + ], + swiftSettings: strictConcurrencySettings ), ] ) + +// --- STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // +for target in package.targets { + switch target.type { + case .regular, .test, .executable: + var settings = target.swiftSettings ?? [] + // https://github.com/swiftlang/swift-evolution/blob/main/proposals/0444-member-import-visibility.md + settings.append(.enableUpcomingFeature("MemberImportVisibility")) + target.swiftSettings = settings + case .macro, .plugin, .system, .binary: + () // not applicable + @unknown default: + () // we don't know what to do here, do nothing + } +} +// --- END: STANDARD CROSS-REPO SETTINGS DO NOT EDIT --- // diff --git a/README.md b/README.md index 871eb910b..a4f49c8c8 100644 --- a/README.md +++ b/README.md @@ -306,7 +306,7 @@ Please have a look at [SECURITY.md](SECURITY.md) for AsyncHTTPClient's security ## Supported Versions -The most recent versions of AsyncHTTPClient support Swift 5.6 and newer. The minimum Swift version supported by AsyncHTTPClient releases are detailed below: +The most recent versions of AsyncHTTPClient support Swift 5.10 and newer. The minimum Swift version supported by AsyncHTTPClient releases are detailed below: AsyncHTTPClient | Minimum Swift Version --------------------|---------------------- @@ -316,4 +316,6 @@ AsyncHTTPClient | Minimum Swift Version `1.13.0 ..< 1.18.0` | 5.5.2 `1.18.0 ..< 1.20.0` | 5.6 `1.20.0 ..< 1.21.0` | 5.7 -`1.21.0 ...` | 5.8 +`1.21.0 ..< 1.26.0` | 5.8 +`1.26.0 ..< 1.27.0` | 5.9 +`1.27.0 ...` | 5.10 diff --git a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift index 8f6b32bd2..fbcc82ec1 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/AnyAsyncSequence.swift @@ -46,3 +46,6 @@ struct AnyAsyncSequence: Sendable, AsyncSequence { .init(nextCallback: self.makeAsyncIteratorCallback()) } } + +@available(*, unavailable) +extension AnyAsyncSequence.AsyncIterator: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift index ef858443e..5fc1be9f5 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClient+execute.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import Logging import NIOCore import NIOHTTP1 +import struct Foundation.URL + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClient { /// Execute arbitrary HTTP requests. @@ -25,6 +26,10 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, @@ -50,6 +55,10 @@ extension HTTPClient { /// - request: HTTP request to execute. /// - timeout: time the the request has to complete. /// - logger: The logger to use for this request. + /// + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. + /// /// - Returns: The response to the request. Note that the `body` of the response may not yet have been fully received. public func execute( _ request: HTTPClientRequest, @@ -66,6 +75,8 @@ extension HTTPClient { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClient { + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeAndFollowRedirectsIfNeeded( _ request: HTTPClientRequest, deadline: NIODeadline, @@ -74,22 +85,42 @@ extension HTTPClient { ) async throws -> HTTPClientResponse { var currentRequest = request var currentRedirectState = redirectState + var history: [HTTPClientRequestResponse] = [] // this loop is there to follow potential redirects while true { let preparedRequest = try HTTPClientRequest.Prepared(currentRequest, dnsOverride: configuration.dnsOverride) - let response = try await self.executeCancellable(preparedRequest, deadline: deadline, logger: logger) + let response = try await { + var response = try await self.executeCancellable(preparedRequest, deadline: deadline, logger: logger) + + history.append( + .init( + request: currentRequest, + responseHead: .init( + version: response.version, + status: response.status, + headers: response.headers + ) + ) + ) + + response.history = history + + return response + }() guard var redirectState = currentRedirectState else { // a `nil` redirectState means we should not follow redirects return response } - guard let redirectURL = response.headers.extractRedirectTarget( - status: response.status, - originalURL: preparedRequest.url, - originalScheme: preparedRequest.poolKey.scheme - ) else { + guard + let redirectURL = response.headers.extractRedirectTarget( + status: response.status, + originalURL: preparedRequest.url, + originalScheme: preparedRequest.poolKey.scheme + ) + else { // response does not want a redirect return response } @@ -113,6 +144,8 @@ extension HTTPClient { } } + /// - warning: This method may violates Structured Concurrency because it returns a `HTTPClientResponse` that needs to be + /// streamed by the user. This means the request, the connection and other resources are still alive when the request returns. private func executeCancellable( _ request: HTTPClientRequest.Prepared, deadline: NIODeadline, @@ -120,31 +153,35 @@ extension HTTPClient { ) async throws -> HTTPClientResponse { let cancelHandler = TransactionCancelHandler() - return try await withTaskCancellationHandler(operation: { () async throws -> HTTPClientResponse in - let eventLoop = self.eventLoopGroup.any() - let deadlineTask = eventLoop.scheduleTask(deadline: deadline) { - cancelHandler.cancel(reason: .deadlineExceeded) + return try await withTaskCancellationHandler( + operation: { () async throws -> HTTPClientResponse in + let eventLoop = self.eventLoopGroup.any() + let deadlineTask = eventLoop.scheduleTask(deadline: deadline) { + cancelHandler.cancel(reason: .deadlineExceeded) + } + defer { + deadlineTask.cancel() + } + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) -> Void in + let transaction = Transaction( + request: request, + requestOptions: .fromClientConfiguration(self.configuration), + logger: logger, + connectionDeadline: .now() + (self.configuration.timeout.connectionCreationTimeout), + preferredEventLoop: eventLoop, + responseContinuation: continuation + ) + + cancelHandler.registerTransaction(transaction) + + self.poolManager.executeRequest(transaction) + } + }, + onCancel: { + cancelHandler.cancel(reason: .taskCanceled) } - defer { - deadlineTask.cancel() - } - return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) -> Void in - let transaction = Transaction( - request: request, - requestOptions: .fromClientConfiguration(self.configuration), - logger: logger, - connectionDeadline: .now() + (self.configuration.timeout.connectionCreationTimeout), - preferredEventLoop: eventLoop, - responseContinuation: continuation - ) - - cancelHandler.registerTransaction(transaction) - - self.poolManager.executeRequest(transaction) - } - }, onCancel: { - cancelHandler.cancel(reason: .taskCanceled) - }) + ) } } diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift index 360e91b89..c39452897 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+Prepared.swift @@ -12,18 +12,19 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import NIOCore import NIOHTTP1 import NIOSSL +import struct Foundation.URL + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest { struct Prepared { enum Body { case asyncSequence( length: RequestBodyLength, - nextBodyPart: (ByteBufferAllocator) async throws -> ByteBuffer? + makeAsyncIterator: @Sendable () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) ) case sequence( length: RequestBodyLength, @@ -45,7 +46,7 @@ extension HTTPClientRequest { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientRequest.Prepared { init(_ request: HTTPClientRequest, dnsOverride: [String: String] = [:]) throws { - guard let url = URL(string: request.url) else { + guard !request.url.isEmpty, let url = URL(string: request.url) else { throw HTTPClientError.invalidURL } @@ -79,9 +80,13 @@ extension HTTPClientRequest.Prepared.Body { init(_ body: HTTPClientRequest.Body) { switch body.mode { case .asyncSequence(let length, let makeAsyncIterator): - self = .asyncSequence(length: length, nextBodyPart: makeAsyncIterator()) + self = .asyncSequence(length: length, makeAsyncIterator: makeAsyncIterator) case .sequence(let length, let canBeConsumedMultipleTimes, let makeCompleteBody): - self = .sequence(length: length, canBeConsumedMultipleTimes: canBeConsumedMultipleTimes, makeCompleteBody: makeCompleteBody) + self = .sequence( + length: length, + canBeConsumedMultipleTimes: canBeConsumedMultipleTimes, + makeCompleteBody: makeCompleteBody + ) case .byteBuffer(let byteBuffer): self = .byteBuffer(byteBuffer) } @@ -95,7 +100,7 @@ extension RequestBodyLength { case .none: self = .known(0) case .byteBuffer(let buffer): - self = .known(buffer.readableBytes) + self = .known(Int64(buffer.readableBytes)) case .sequence(let length, _, _), .asyncSequence(let length, _): self = length } diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift new file mode 100644 index 000000000..106a8f76b --- /dev/null +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest+auth.swift @@ -0,0 +1,27 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2024 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClientRequest { + /// Set basic auth for a request. + /// + /// - parameters: + /// - username: the username to authenticate with + /// - password: authentication password associated with the username + public mutating func setBasicAuth(username: String, password: String) { + self.headers.setBasicAuth(username: username, password: password) + } +} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift index 4ed79e38c..dca7de0ef 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientRequest.swift @@ -125,7 +125,7 @@ extension HTTPClientRequest.Body { public static func bytes( _ bytes: Bytes ) -> Self where Bytes.Element == UInt8 { - self.bytes(bytes, length: .known(bytes.count)) + self.bytes(bytes, length: .known(Int64(bytes.count))) } /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `Sequence` of bytes. @@ -140,7 +140,7 @@ extension HTTPClientRequest.Body { /// /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload - /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)`` will use `Content-Length`. + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. /// /// - parameters: /// - bytes: The bytes of the request body. @@ -176,23 +176,29 @@ extension HTTPClientRequest.Body { // the maximum size of a ByteBuffer. if bufferPointer.count <= byteBufferMaxSize { let buffer = ByteBuffer(bytes: bufferPointer) - return Self(.sequence( - length: length.storage, - canBeConsumedMultipleTimes: true, - makeCompleteBody: { _ in buffer } - )) + return Self( + .sequence( + length: length.storage, + canBeConsumedMultipleTimes: true, + makeCompleteBody: { _ in buffer } + ) + ) } else { // we need to copy `bufferPointer` eagerly as the pointer is only valid during the call to `withContiguousStorageIfAvailable` - let buffers: Array = bufferPointer.chunks(ofCount: byteBufferMaxSize).map { ByteBuffer(bytes: $0) } - return Self(.asyncSequence( - length: length.storage, - makeAsyncIterator: { - var iterator = buffers.makeIterator() - return { _ in - iterator.next() + let buffers: [ByteBuffer] = bufferPointer.chunks(ofCount: byteBufferMaxSize).map { + ByteBuffer(bytes: $0) + } + return Self( + .asyncSequence( + length: length.storage, + makeAsyncIterator: { + var iterator = buffers.makeIterator() + return { _ in + iterator.next() + } } - } - )) + ) + ) } } if let body = body { @@ -200,21 +206,23 @@ extension HTTPClientRequest.Body { } // slow path - return Self(.asyncSequence( - length: length.storage - ) { - var iterator = bytes.makeIterator() - return { allocator in - var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) - while buffer.writableBytes > 0, let byte = iterator.next() { - buffer.writeInteger(byte) - } - if buffer.readableBytes > 0 { - return buffer + return Self( + .asyncSequence( + length: length.storage + ) { + var iterator = bytes.makeIterator() + return { allocator in + var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) + while buffer.writableBytes > 0, let byte = iterator.next() { + buffer.writeInteger(byte) + } + if buffer.readableBytes > 0 { + return buffer + } + return nil } - return nil } - }) + ) } /// Create an ``HTTPClientRequest/Body-swift.struct`` from a `Collection` of bytes. @@ -225,7 +233,7 @@ extension HTTPClientRequest.Body { /// /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload - /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)`` will use `Content-Length`. + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. /// /// - parameters: /// - bytes: The bytes of the request body. @@ -237,25 +245,29 @@ extension HTTPClientRequest.Body { length: Length ) -> Self where Bytes.Element == UInt8 { if bytes.count <= bagOfBytesToByteBufferConversionChunkSize { - return self.init(.sequence( - length: length.storage, - canBeConsumedMultipleTimes: true - ) { allocator in - allocator.buffer(bytes: bytes) - }) + return self.init( + .sequence( + length: length.storage, + canBeConsumedMultipleTimes: true + ) { allocator in + allocator.buffer(bytes: bytes) + } + ) } else { - return self.init(.asyncSequence( - length: length.storage, - makeAsyncIterator: { - var iterator = bytes.chunks(ofCount: bagOfBytesToByteBufferConversionChunkSize).makeIterator() - return { allocator in - guard let chunk = iterator.next() else { - return nil + return self.init( + .asyncSequence( + length: length.storage, + makeAsyncIterator: { + var iterator = bytes.chunks(ofCount: bagOfBytesToByteBufferConversionChunkSize).makeIterator() + return { allocator in + guard let chunk = iterator.next() else { + return nil + } + return allocator.buffer(bytes: chunk) } - return allocator.buffer(bytes: chunk) } - } - )) + ) + ) } } @@ -265,7 +277,7 @@ extension HTTPClientRequest.Body { /// /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload - /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)`` will use `Content-Length`. + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. /// /// - parameters: /// - sequenceOfBytes: The bytes of the request body. @@ -276,12 +288,14 @@ extension HTTPClientRequest.Body { _ sequenceOfBytes: SequenceOfBytes, length: Length ) -> Self where SequenceOfBytes.Element == ByteBuffer { - let body = self.init(.asyncSequence(length: length.storage) { - var iterator = sequenceOfBytes.makeAsyncIterator() - return { _ -> ByteBuffer? in - try await iterator.next() + let body = self.init( + .asyncSequence(length: length.storage) { + var iterator = sequenceOfBytes.makeAsyncIterator() + return { _ -> ByteBuffer? in + try await iterator.next() + } } - }) + ) return body } @@ -293,7 +307,7 @@ extension HTTPClientRequest.Body { /// /// Caution should be taken with this method to ensure that the `length` is correct. Incorrect lengths /// will cause unnecessary runtime failures. Setting `length` to ``Length/unknown`` will trigger the upload - /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)`` will use `Content-Length`. + /// to use `chunked` `Transfer-Encoding`, while using ``Length/known(_:)-9q0ge`` will use `Content-Length`. /// /// - parameters: /// - bytes: The bytes of the request body. @@ -304,19 +318,21 @@ extension HTTPClientRequest.Body { _ bytes: Bytes, length: Length ) -> Self where Bytes.Element == UInt8 { - let body = self.init(.asyncSequence(length: length.storage) { - var iterator = bytes.makeAsyncIterator() - return { allocator -> ByteBuffer? in - var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) - while buffer.writableBytes > 0, let byte = try await iterator.next() { - buffer.writeInteger(byte) - } - if buffer.readableBytes > 0 { - return buffer + let body = self.init( + .asyncSequence(length: length.storage) { + var iterator = bytes.makeAsyncIterator() + return { allocator -> ByteBuffer? in + var buffer = allocator.buffer(capacity: bagOfBytesToByteBufferConversionChunkSize) + while buffer.writableBytes > 0, let byte = try await iterator.next() { + buffer.writeInteger(byte) + } + if buffer.readableBytes > 0 { + return buffer + } + return nil } - return nil } - }) + ) return body } } @@ -341,7 +357,13 @@ extension HTTPClientRequest.Body { public static let unknown: Self = .init(storage: .unknown) /// The size of the request body is known and exactly `count` bytes + @available(*, deprecated, message: "Use `known(_ count: Int64)` with an explicit Int64 argument instead") public static func known(_ count: Int) -> Self { + .init(storage: .known(Int64(count))) + } + + /// The size of the request body is known and exactly `count` bytes + public static func known(_ count: Int64) -> Self { .init(storage: .known(count)) } @@ -399,3 +421,9 @@ extension HTTPClientRequest.Body { } } } + +@available(*, unavailable) +extension HTTPClientRequest.Body.AsyncIterator: Sendable {} + +@available(*, unavailable) +extension HTTPClientRequest.Body.AsyncIterator.Storage: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift index ee7f11592..36c1cb36f 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/HTTPClientResponse.swift @@ -15,6 +15,8 @@ import NIOCore import NIOHTTP1 +import struct Foundation.URL + /// A representation of an HTTP response for the Swift Concurrency HTTPClient API. /// /// This object is similar to ``HTTPClient/Response``, but used for the Swift Concurrency API. @@ -32,6 +34,18 @@ public struct HTTPClientResponse: Sendable { /// The body of this HTTP response. public var body: Body + /// The history of all requests and responses in redirect order. + public var history: [HTTPClientRequestResponse] + + /// The target URL (after redirects) of the response. + public var url: URL? { + guard let lastRequestURL = self.history.last?.request.url else { + return nil + } + + return URL(string: lastRequestURL) + } + @inlinable public init( version: HTTPVersion = .http1_1, status: HTTPResponseStatus = .ok, @@ -42,6 +56,21 @@ public struct HTTPClientResponse: Sendable { self.status = status self.headers = headers self.body = body + self.history = [] + } + + @inlinable public init( + version: HTTPVersion = .http1_1, + status: HTTPResponseStatus = .ok, + headers: HTTPHeaders = [:], + body: Body = Body(), + history: [HTTPClientRequestResponse] = [] + ) { + self.version = version + self.status = status + self.headers = headers + self.body = body + self.history = history } init( @@ -49,24 +78,39 @@ public struct HTTPClientResponse: Sendable { version: HTTPVersion, status: HTTPResponseStatus, headers: HTTPHeaders, - body: TransactionBody + body: TransactionBody, + history: [HTTPClientRequestResponse] ) { self.init( version: version, status: status, headers: headers, - body: .init(.transaction( - body, - expectedContentLength: HTTPClientResponse.expectedContentLength( - requestMethod: requestMethod, - headers: headers, - status: status + body: .init( + .transaction( + body, + expectedContentLength: HTTPClientResponse.expectedContentLength( + requestMethod: requestMethod, + headers: headers, + status: status + ) ) - )) + ), + history: history ) } } +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +public struct HTTPClientRequestResponse: Sendable { + public var request: HTTPClientRequest + public var responseHead: HTTPResponseHead + + public init(request: HTTPClientRequest, responseHead: HTTPResponseHead) { + self.request = request + self.responseHead = responseHead + } +} + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientResponse { /// A representation of the response body for an HTTP response. @@ -108,7 +152,7 @@ extension HTTPClientResponse { case .transaction(_, let expectedContentLength): if let contentLength = expectedContentLength { if contentLength > maxBytes { - throw NIOTooManyBytesError() + throw NIOTooManyBytesError(maxBytes: maxBytes) } } case .anyAsyncSequence: @@ -116,7 +160,8 @@ extension HTTPClientResponse { } /// calling collect function within here in order to ensure the correct nested type - func collect(_ body: Body, maxBytes: Int) async throws -> ByteBuffer where Body.Element == ByteBuffer { + func collect(_ body: Body, maxBytes: Int) async throws -> ByteBuffer + where Body.Element == ByteBuffer { try await body.collect(upTo: maxBytes) } return try await collect(self, maxBytes: maxBytes) @@ -126,7 +171,11 @@ extension HTTPClientResponse { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension HTTPClientResponse { - static func expectedContentLength(requestMethod: HTTPMethod, headers: HTTPHeaders, status: HTTPResponseStatus) -> Int? { + static func expectedContentLength( + requestMethod: HTTPMethod, + headers: HTTPHeaders, + status: HTTPResponseStatus + ) -> Int? { if status == .notModified { return 0 } else if requestMethod == .HEAD { @@ -210,3 +259,9 @@ extension HTTPClientResponse.Body { .stream(CollectionOfOne(byteBuffer).async) } } + +@available(*, unavailable) +extension HTTPClientResponse.Body.AsyncIterator: Sendable {} + +@available(*, unavailable) +extension HTTPClientResponse.Body.Storage.AsyncIterator: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift index ad49332c0..457627a8a 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction+StateMachine.swift @@ -34,14 +34,14 @@ extension Transaction { case finished(error: Error?) } - fileprivate enum RequestStreamState { + fileprivate enum RequestStreamState: Sendable { case requestHeadSent case producing case paused(continuation: CheckedContinuation?) case finished } - fileprivate enum ResponseStreamState { + fileprivate enum ResponseStreamState: Sendable { // Waiting for response head. Valid transitions to: streamingBody. case waitingForResponseHead // streaming response body. Valid transitions to: finished. @@ -82,9 +82,20 @@ extension Transaction { enum FailAction { case none /// fail response before head received. scheduler and executor are exclusive here. - case failResponseHead(CheckedContinuation, Error, HTTPRequestScheduler?, HTTPRequestExecutor?, bodyStreamContinuation: CheckedContinuation?) + case failResponseHead( + CheckedContinuation, + Error, + HTTPRequestScheduler?, + HTTPRequestExecutor?, + bodyStreamContinuation: CheckedContinuation? + ) /// fail response after response head received. fail the response stream (aka call to `next()`) - case failResponseStream(TransactionBody.Source, Error, HTTPRequestExecutor, bodyStreamContinuation: CheckedContinuation?) + case failResponseStream( + TransactionBody.Source, + Error, + HTTPRequestExecutor, + bodyStreamContinuation: CheckedContinuation? + ) case failRequestStreamContinuation(CheckedContinuation, Error) } @@ -116,24 +127,41 @@ extension Transaction { switch requestStreamState { case .paused(continuation: .some(let continuation)): self.state = .finished(error: error) - return .failResponseHead(context.continuation, error, nil, context.executor, bodyStreamContinuation: continuation) + return .failResponseHead( + context.continuation, + error, + nil, + context.executor, + bodyStreamContinuation: continuation + ) case .requestHeadSent, .finished, .producing, .paused(continuation: .none): self.state = .finished(error: error) - return .failResponseHead(context.continuation, error, nil, context.executor, bodyStreamContinuation: nil) + return .failResponseHead( + context.continuation, + error, + nil, + context.executor, + bodyStreamContinuation: nil + ) } case .executing(let context, let requestStreamState, .streamingBody(let source)): self.state = .finished(error: error) switch requestStreamState { case .paused(let bodyStreamContinuation): - return .failResponseStream(source, error, context.executor, bodyStreamContinuation: bodyStreamContinuation) + return .failResponseStream( + source, + error, + context.executor, + bodyStreamContinuation: bodyStreamContinuation + ) case .finished, .producing, .requestHeadSent: return .failResponseStream(source, error, context.executor, bodyStreamContinuation: nil) } case .finished(error: _), - .executing(_, _, .finished): + .executing(_, _, .finished): return .none } } @@ -165,7 +193,7 @@ extension Transaction { return .cancel(executor) case .executing, - .finished(error: .none): + .finished(error: .none): preconditionFailure("Invalid state: \(self.state)") } } @@ -179,7 +207,9 @@ extension Transaction { mutating func resumeRequestBodyStream() -> ResumeProducingAction { switch self.state { case .initialized, .queued, .deadlineExceededWhileQueued: - preconditionFailure("Received a resumeBodyRequest on a request, that isn't executing. Invalid state: \(self.state)") + preconditionFailure( + "Received a resumeBodyRequest on a request, that isn't executing. Invalid state: \(self.state)" + ) case .executing(let context, .requestHeadSent, let responseState): // the request can start to send its body. @@ -187,7 +217,9 @@ extension Transaction { return .startStream(context.allocator) case .executing(_, .producing, _): - preconditionFailure("Received a resumeBodyRequest on a request, that is producing. Invalid state: \(self.state)") + preconditionFailure( + "Received a resumeBodyRequest on a request, that is producing. Invalid state: \(self.state)" + ) case .executing(let context, .paused(.none), let responseState): // request stream is currently paused, but there is no write waiting. We don't need @@ -213,17 +245,17 @@ extension Transaction { mutating func pauseRequestBodyStream() { switch self.state { case .initialized, - .queued, - .deadlineExceededWhileQueued, - .executing(_, .requestHeadSent, _): + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _): preconditionFailure("A request stream can only be resumed, if the request was started") case .executing(let context, .producing, let responseSteam): self.state = .executing(context, .paused(continuation: nil), responseSteam) case .executing(_, .paused, _), - .executing(_, .finished, _), - .finished: + .executing(_, .finished, _), + .finished: // the channels writability changed to paused after we have already forwarded all // request bytes. Can be ignored. break @@ -239,10 +271,12 @@ extension Transaction { func writeNextRequestPart() -> NextWriteAction { switch self.state { case .initialized, - .queued, - .deadlineExceededWhileQueued, - .executing(_, .requestHeadSent, _): - preconditionFailure("A request stream can only produce, if the request was started. Invalid state: \(self.state)") + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _): + preconditionFailure( + "A request stream can only produce, if the request was started. Invalid state: \(self.state)" + ) case .executing(let context, .producing, _): // We are currently producing the request body. The executors channel is writable. @@ -260,7 +294,9 @@ extension Transaction { return .writeAndWait(context.executor) case .executing(_, .paused(continuation: .some), _): - preconditionFailure("A write continuation already exists, but we tried to set another one. Invalid state: \(self.state)") + preconditionFailure( + "A write continuation already exists, but we tried to set another one. Invalid state: \(self.state)" + ) case .finished, .executing(_, .finished, _): return .fail @@ -270,11 +306,13 @@ extension Transaction { mutating func waitForRequestBodyDemand(continuation: CheckedContinuation) { switch self.state { case .initialized, - .queued, - .deadlineExceededWhileQueued, - .executing(_, .requestHeadSent, _), - .executing(_, .finished, _): - preconditionFailure("A request stream can only produce, if the request was started. Invalid state: \(self.state)") + .queued, + .deadlineExceededWhileQueued, + .executing(_, .requestHeadSent, _), + .executing(_, .finished, _): + preconditionFailure( + "A request stream can only produce, if the request was started. Invalid state: \(self.state)" + ) case .executing(_, .producing, _): preconditionFailure() @@ -303,17 +341,19 @@ extension Transaction { mutating func finishRequestBodyStream() -> FinishAction { switch self.state { case .initialized, - .queued, - .deadlineExceededWhileQueued, - .executing(_, .finished, _): + .queued, + .deadlineExceededWhileQueued, + .executing(_, .finished, _): preconditionFailure("Invalid state: \(self.state)") case .executing(_, .paused(continuation: .some), _): - preconditionFailure("Received a request body end, while having a registered back-pressure continuation. Invalid state: \(self.state)") + preconditionFailure( + "Received a request body end, while having a registered back-pressure continuation. Invalid state: \(self.state)" + ) case .executing(let context, .producing, let responseState), - .executing(let context, .paused(continuation: .none), let responseState), - .executing(let context, .requestHeadSent, let responseState): + .executing(let context, .paused(continuation: .none), let responseState), + .executing(let context, .requestHeadSent, let responseState): switch responseState { case .finished: @@ -345,10 +385,10 @@ extension Transaction { ) -> ReceiveResponseHeadAction { switch self.state { case .initialized, - .queued, - .deadlineExceededWhileQueued, - .executing(_, _, .streamingBody), - .executing(_, _, .finished): + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .streamingBody), + .executing(_, _, .finished): preconditionFailure("invalid state \(self.state)") case .executing(let context, let requestState, .waitingForResponseHead): @@ -381,15 +421,15 @@ extension Transaction { mutating func produceMore() -> ProduceMoreAction { switch self.state { case .initialized, - .queued, - .deadlineExceededWhileQueued, - .executing(_, _, .waitingForResponseHead): + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .waitingForResponseHead): preconditionFailure("invalid state \(self.state)") case .executing(let context, _, .streamingBody): return .requestMoreResponseBodyParts(context.executor) case .finished, - .executing(_, _, .finished): + .executing(_, _, .finished): return .none } } @@ -402,7 +442,9 @@ extension Transaction { mutating func receiveResponseBodyParts(_ buffer: CircularBuffer) -> ReceiveResponsePartAction { switch self.state { case .initialized, .queued, .deadlineExceededWhileQueued: - preconditionFailure("Received a response body part, but request hasn't started yet. Invalid state: \(self.state)") + preconditionFailure( + "Received a response body part, but request hasn't started yet. Invalid state: \(self.state)" + ) case .executing(_, _, .waitingForResponseHead): preconditionFailure("If we receive a response body, we must have received a head before") @@ -415,7 +457,9 @@ extension Transaction { return .none case .executing(_, _, .finished): - preconditionFailure("Received response end. Must not receive further body parts after that. Invalid state: \(self.state)") + preconditionFailure( + "Received response end. Must not receive further body parts after that. Invalid state: \(self.state)" + ) } } @@ -427,10 +471,12 @@ extension Transaction { mutating func succeedRequest(_ newChunks: CircularBuffer?) -> ReceiveResponseEndAction { switch self.state { case .initialized, - .queued, - .deadlineExceededWhileQueued, - .executing(_, _, .waitingForResponseHead): - preconditionFailure("Received no response head, but received a response end. Invalid state: \(self.state)") + .queued, + .deadlineExceededWhileQueued, + .executing(_, _, .waitingForResponseHead): + preconditionFailure( + "Received no response head, but received a response end. Invalid state: \(self.state)" + ) case .executing(let context, let requestState, .streamingBody(let source)): self.state = .executing(context, requestState, .finished) @@ -439,7 +485,9 @@ extension Transaction { // the request failed or was cancelled before, we can ignore all events return .none case .executing(_, _, .finished): - preconditionFailure("Already received an eof or error before. Must not receive further events. Invalid state: \(self.state)") + preconditionFailure( + "Already received an eof or error before. Must not receive further events. Invalid state: \(self.state)" + ) } } @@ -504,3 +552,6 @@ extension Transaction { } } } + +@available(*, unavailable) +extension Transaction.StateMachine: Sendable {} diff --git a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift index 6d9192642..6bf8b38b7 100644 --- a/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift +++ b/Sources/AsyncHTTPClient/AsyncAwait/Transaction.swift @@ -19,7 +19,11 @@ import NIOHTTP1 import NIOSSL @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -@usableFromInline final class Transaction: @unchecked Sendable { +@usableFromInline +final class Transaction: + // until NIOLockedValueBox learns `sending` because StateMachine cannot be Sendable + @unchecked Sendable +{ let logger: Logger let request: HTTPClientRequest.Prepared @@ -28,8 +32,7 @@ import NIOSSL let preferredEventLoop: EventLoop let requestOptions: RequestOptions - private let stateLock = NIOLock() - private var state: StateMachine + private let state: NIOLockedValueBox init( request: HTTPClientRequest.Prepared, @@ -44,7 +47,7 @@ import NIOSSL self.logger = logger self.connectionDeadline = connectionDeadline self.preferredEventLoop = preferredEventLoop - self.state = StateMachine(responseContinuation) + self.state = NIOLockedValueBox(StateMachine(responseContinuation)) } func cancel() { @@ -56,8 +59,8 @@ import NIOSSL private func writeOnceAndOneTimeOnly(byteBuffer: ByteBuffer) { // This method is synchronously invoked after sending the request head. For this reason we // can make a number of assumptions, how the state machine will react. - let writeAction = self.stateLock.withLock { - self.state.writeNextRequestPart() + let writeAction = self.state.withLockedValue { state in + state.writeNextRequestPart() } switch writeAction { @@ -74,9 +77,11 @@ import NIOSSL private func continueRequestBodyStream( _ allocator: ByteBufferAllocator, - next: @escaping ((ByteBufferAllocator) async throws -> ByteBuffer?) + makeAsyncIterator: @Sendable @escaping () -> ((ByteBufferAllocator) async throws -> ByteBuffer?) ) { Task { + let next = makeAsyncIterator() + do { while let part = try await next(allocator) { do { @@ -99,30 +104,33 @@ import NIOSSL struct BreakTheWriteLoopError: Swift.Error {} + // FIXME: Refactor this to not use `self.state.unsafe`. private func writeRequestBodyPart(_ part: ByteBuffer) async throws { - self.stateLock.lock() - switch self.state.writeNextRequestPart() { + self.state.unsafe.lock() + switch self.state.unsafe.withValueAssumingLockIsAcquired({ state in state.writeNextRequestPart() }) { case .writeAndContinue(let executor): - self.stateLock.unlock() + self.state.unsafe.unlock() executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) case .writeAndWait(let executor): try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in - self.state.waitForRequestBodyDemand(continuation: continuation) - self.stateLock.unlock() + self.state.unsafe.withValueAssumingLockIsAcquired({ state in + state.waitForRequestBodyDemand(continuation: continuation) + }) + self.state.unsafe.unlock() executor.writeRequestBodyPart(.byteBuffer(part), request: self, promise: nil) } case .fail: - self.stateLock.unlock() + self.state.unsafe.unlock() throw BreakTheWriteLoopError() } } private func requestBodyStreamFinished() { - let finishAction = self.stateLock.withLock { - self.state.finishRequestBodyStream() + let finishAction = self.state.withLockedValue { state in + state.finishRequestBodyStream() } switch finishAction { @@ -146,12 +154,12 @@ import NIOSSL @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction: HTTPSchedulableRequest { var poolKey: ConnectionPool.Key { self.request.poolKey } - var tlsConfiguration: TLSConfiguration? { return self.request.tlsConfiguration } - var requiredEventLoop: EventLoop? { return nil } + var tlsConfiguration: TLSConfiguration? { self.request.tlsConfiguration } + var requiredEventLoop: EventLoop? { nil } func requestWasQueued(_ scheduler: HTTPRequestScheduler) { - self.stateLock.withLock { - self.state.requestWasQueued(scheduler) + self.state.withLockedValue { state in + state.requestWasQueued(scheduler) } } } @@ -165,8 +173,8 @@ extension Transaction: HTTPExecutableRequest { // MARK: Request func willExecuteRequest(_ executor: HTTPRequestExecutor) { - let action = self.stateLock.withLock { - self.state.willExecuteRequest(executor) + let action = self.state.withLockedValue { state in + state.willExecuteRequest(executor) } switch action { @@ -183,8 +191,8 @@ extension Transaction: HTTPExecutableRequest { func requestHeadSent() {} func resumeRequestBodyStream() { - let action = self.stateLock.withLock { - self.state.resumeRequestBodyStream() + let action = self.state.withLockedValue { state in + state.resumeRequestBodyStream() } switch action { @@ -193,9 +201,9 @@ extension Transaction: HTTPExecutableRequest { case .startStream(let allocator): switch self.request.body { - case .asyncSequence(_, let next): + case .asyncSequence(_, let makeAsyncIterator): // it is safe to call this async here. it dispatches... - self.continueRequestBodyStream(allocator, next: next) + self.continueRequestBodyStream(allocator, makeAsyncIterator: makeAsyncIterator) case .byteBuffer(let byteBuffer): self.writeOnceAndOneTimeOnly(byteBuffer: byteBuffer) @@ -214,16 +222,16 @@ extension Transaction: HTTPExecutableRequest { } func pauseRequestBodyStream() { - self.stateLock.withLock { - self.state.pauseRequestBodyStream() + self.state.withLockedValue { state in + state.pauseRequestBodyStream() } } // MARK: Response func receiveResponseHead(_ head: HTTPResponseHead) { - let action = self.stateLock.withLock { - self.state.receiveResponseHead(head, delegate: self) + let action = self.state.withLockedValue { state in + state.receiveResponseHead(head, delegate: self) } switch action { @@ -236,15 +244,16 @@ extension Transaction: HTTPExecutableRequest { version: head.version, status: head.status, headers: head.headers, - body: body + body: body, + history: [] ) continuation.resume(returning: response) } } func receiveResponseBodyParts(_ buffer: CircularBuffer) { - let action = self.stateLock.withLock { - self.state.receiveResponseBodyParts(buffer) + let action = self.state.withLockedValue { state in + state.receiveResponseBodyParts(buffer) } switch action { case .none: @@ -260,8 +269,8 @@ extension Transaction: HTTPExecutableRequest { } func succeedRequest(_ buffer: CircularBuffer?) { - let succeedAction = self.stateLock.withLock { - self.state.succeedRequest(buffer) + let succeedAction = self.state.withLockedValue { state in + state.succeedRequest(buffer) } switch succeedAction { case .finishResponseStream(let source, let finalResponse): @@ -276,8 +285,8 @@ extension Transaction: HTTPExecutableRequest { } func fail(_ error: Error) { - let action = self.stateLock.withLock { - self.state.fail(error) + let action = self.state.withLockedValue { state in + state.fail(error) } self.performFailAction(action) } @@ -290,7 +299,7 @@ extension Transaction: HTTPExecutableRequest { case .failResponseHead(let continuation, let error, let scheduler, let executor, let bodyStreamContinuation): continuation.resume(throwing: error) bodyStreamContinuation?.resume(throwing: error) - scheduler?.cancelRequest(self) // NOTE: scheduler and executor are exclusive here + scheduler?.cancelRequest(self) // NOTE: scheduler and executor are exclusive here executor?.cancelRequest(self) case .failResponseStream(let source, let error, let executor, let requestBodyStreamContinuation): @@ -304,8 +313,8 @@ extension Transaction: HTTPExecutableRequest { } func deadlineExceeded() { - let action = self.stateLock.withLock { - self.state.deadlineExceeded() + let action = self.state.withLockedValue { state in + state.deadlineExceeded() } self.performDeadlineExceededAction(action) } @@ -317,7 +326,7 @@ extension Transaction: HTTPExecutableRequest { scheduler?.cancelRequest(self) executor?.cancelRequest(self) bodyStreamContinuation?.resume(throwing: HTTPClientError.deadlineExceeded) - case .cancelSchedulerOnly(scheduler: let scheduler): + case .cancelSchedulerOnly(let scheduler): scheduler.cancelRequest(self) case .none: break @@ -329,8 +338,8 @@ extension Transaction: HTTPExecutableRequest { extension Transaction: NIOAsyncSequenceProducerDelegate { @usableFromInline func produceMore() { - let action = self.stateLock.withLock { - self.state.produceMore() + let action = self.state.withLockedValue { state in + state.produceMore() } switch action { case .none: diff --git a/Sources/AsyncHTTPClient/Base64.swift b/Sources/AsyncHTTPClient/Base64.swift index eed511a8c..4d2ddcc49 100644 --- a/Sources/AsyncHTTPClient/Base64.swift +++ b/Sources/AsyncHTTPClient/Base64.swift @@ -19,142 +19,156 @@ extension String { - /// Base64 encode a collection of UInt8 to a string, without the use of Foundation. - @inlinable - init(base64Encoding bytes: Buffer) - where Buffer.Element == UInt8 - { - self = Base64.encode(bytes: bytes) - } + /// Base64 encode a collection of UInt8 to a string, without the use of Foundation. + @inlinable + init(base64Encoding bytes: Buffer) + where Buffer.Element == UInt8 { + self = Base64.encode(bytes: bytes) + } } +// swift-format-ignore: DontRepeatTypeInStaticProperties @usableFromInline -internal struct Base64 { - - @inlinable - static func encode(bytes: Buffer) - -> String where Buffer.Element == UInt8 - { - guard !bytes.isEmpty else { - return "" - } - // In Base64, 3 bytes become 4 output characters, and we pad to the - // nearest multiple of four. - let base64StringLength = ((bytes.count + 2) / 3) * 4 - let alphabet = Base64.encodeBase64 - - return String(customUnsafeUninitializedCapacity: base64StringLength) { backingStorage in - var input = bytes.makeIterator() - var offset = 0 - while let firstByte = input.next() { - let secondByte = input.next() - let thirdByte = input.next() - - backingStorage[offset] = Base64.encode(alphabet: alphabet, firstByte: firstByte) - backingStorage[offset + 1] = Base64.encode(alphabet: alphabet, firstByte: firstByte, secondByte: secondByte) - backingStorage[offset + 2] = Base64.encode(alphabet: alphabet, secondByte: secondByte, thirdByte: thirdByte) - backingStorage[offset + 3] = Base64.encode(alphabet: alphabet, thirdByte: thirdByte) - offset += 4 - } - return offset +internal struct Base64: Sendable { + + @inlinable + static func encode( + bytes: Buffer + ) + -> String where Buffer.Element == UInt8 + { + guard !bytes.isEmpty else { + return "" + } + // In Base64, 3 bytes become 4 output characters, and we pad to the + // nearest multiple of four. + let base64StringLength = ((bytes.count + 2) / 3) * 4 + let alphabet = Base64.encodeBase64 + + return String(customUnsafeUninitializedCapacity: base64StringLength) { backingStorage in + var input = bytes.makeIterator() + var offset = 0 + while let firstByte = input.next() { + let secondByte = input.next() + let thirdByte = input.next() + + backingStorage[offset] = Base64.encode(alphabet: alphabet, firstByte: firstByte) + backingStorage[offset + 1] = Base64.encode( + alphabet: alphabet, + firstByte: firstByte, + secondByte: secondByte + ) + backingStorage[offset + 2] = Base64.encode( + alphabet: alphabet, + secondByte: secondByte, + thirdByte: thirdByte + ) + backingStorage[offset + 3] = Base64.encode(alphabet: alphabet, thirdByte: thirdByte) + offset += 4 + } + return offset + } } - } - - // MARK: Internal - - // The base64 unicode table. - @usableFromInline - static let encodeBase64: [UInt8] = [ - UInt8(ascii: "A"), UInt8(ascii: "B"), UInt8(ascii: "C"), UInt8(ascii: "D"), - UInt8(ascii: "E"), UInt8(ascii: "F"), UInt8(ascii: "G"), UInt8(ascii: "H"), - UInt8(ascii: "I"), UInt8(ascii: "J"), UInt8(ascii: "K"), UInt8(ascii: "L"), - UInt8(ascii: "M"), UInt8(ascii: "N"), UInt8(ascii: "O"), UInt8(ascii: "P"), - UInt8(ascii: "Q"), UInt8(ascii: "R"), UInt8(ascii: "S"), UInt8(ascii: "T"), - UInt8(ascii: "U"), UInt8(ascii: "V"), UInt8(ascii: "W"), UInt8(ascii: "X"), - UInt8(ascii: "Y"), UInt8(ascii: "Z"), UInt8(ascii: "a"), UInt8(ascii: "b"), - UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), - UInt8(ascii: "g"), UInt8(ascii: "h"), UInt8(ascii: "i"), UInt8(ascii: "j"), - UInt8(ascii: "k"), UInt8(ascii: "l"), UInt8(ascii: "m"), UInt8(ascii: "n"), - UInt8(ascii: "o"), UInt8(ascii: "p"), UInt8(ascii: "q"), UInt8(ascii: "r"), - UInt8(ascii: "s"), UInt8(ascii: "t"), UInt8(ascii: "u"), UInt8(ascii: "v"), - UInt8(ascii: "w"), UInt8(ascii: "x"), UInt8(ascii: "y"), UInt8(ascii: "z"), - UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), - UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), - UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "+"), UInt8(ascii: "/"), - ] - - static let encodePaddingCharacter: UInt8 = UInt8(ascii: "=") - - @usableFromInline - static func encode(alphabet: [UInt8], firstByte: UInt8) -> UInt8 { - let index = firstByte >> 2 - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], firstByte: UInt8, secondByte: UInt8?) -> UInt8 { - var index = (firstByte & 0b00000011) << 4 - if let secondByte = secondByte { - index += (secondByte & 0b11110000) >> 4 + + // MARK: Internal + + // The base64 unicode table. + @usableFromInline + static let encodeBase64: [UInt8] = [ + UInt8(ascii: "A"), UInt8(ascii: "B"), UInt8(ascii: "C"), UInt8(ascii: "D"), + UInt8(ascii: "E"), UInt8(ascii: "F"), UInt8(ascii: "G"), UInt8(ascii: "H"), + UInt8(ascii: "I"), UInt8(ascii: "J"), UInt8(ascii: "K"), UInt8(ascii: "L"), + UInt8(ascii: "M"), UInt8(ascii: "N"), UInt8(ascii: "O"), UInt8(ascii: "P"), + UInt8(ascii: "Q"), UInt8(ascii: "R"), UInt8(ascii: "S"), UInt8(ascii: "T"), + UInt8(ascii: "U"), UInt8(ascii: "V"), UInt8(ascii: "W"), UInt8(ascii: "X"), + UInt8(ascii: "Y"), UInt8(ascii: "Z"), UInt8(ascii: "a"), UInt8(ascii: "b"), + UInt8(ascii: "c"), UInt8(ascii: "d"), UInt8(ascii: "e"), UInt8(ascii: "f"), + UInt8(ascii: "g"), UInt8(ascii: "h"), UInt8(ascii: "i"), UInt8(ascii: "j"), + UInt8(ascii: "k"), UInt8(ascii: "l"), UInt8(ascii: "m"), UInt8(ascii: "n"), + UInt8(ascii: "o"), UInt8(ascii: "p"), UInt8(ascii: "q"), UInt8(ascii: "r"), + UInt8(ascii: "s"), UInt8(ascii: "t"), UInt8(ascii: "u"), UInt8(ascii: "v"), + UInt8(ascii: "w"), UInt8(ascii: "x"), UInt8(ascii: "y"), UInt8(ascii: "z"), + UInt8(ascii: "0"), UInt8(ascii: "1"), UInt8(ascii: "2"), UInt8(ascii: "3"), + UInt8(ascii: "4"), UInt8(ascii: "5"), UInt8(ascii: "6"), UInt8(ascii: "7"), + UInt8(ascii: "8"), UInt8(ascii: "9"), UInt8(ascii: "+"), UInt8(ascii: "/"), + ] + + static let encodePaddingCharacter: UInt8 = UInt8(ascii: "=") + + @usableFromInline + static func encode(alphabet: [UInt8], firstByte: UInt8) -> UInt8 { + let index = firstByte >> 2 + return alphabet[Int(index)] } - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], secondByte: UInt8?, thirdByte: UInt8?) -> UInt8 { - guard let secondByte = secondByte else { - // No second byte means we are just emitting padding. - return Base64.encodePaddingCharacter + + @usableFromInline + static func encode(alphabet: [UInt8], firstByte: UInt8, secondByte: UInt8?) -> UInt8 { + var index = (firstByte & 0b00000011) << 4 + if let secondByte = secondByte { + index += (secondByte & 0b11110000) >> 4 + } + return alphabet[Int(index)] } - var index = (secondByte & 0b00001111) << 2 - if let thirdByte = thirdByte { - index += (thirdByte & 0b11000000) >> 6 + + @usableFromInline + static func encode(alphabet: [UInt8], secondByte: UInt8?, thirdByte: UInt8?) -> UInt8 { + guard let secondByte = secondByte else { + // No second byte means we are just emitting padding. + return Base64.encodePaddingCharacter + } + var index = (secondByte & 0b00001111) << 2 + if let thirdByte = thirdByte { + index += (thirdByte & 0b11000000) >> 6 + } + return alphabet[Int(index)] } - return alphabet[Int(index)] - } - - @usableFromInline - static func encode(alphabet: [UInt8], thirdByte: UInt8?) -> UInt8 { - guard let thirdByte = thirdByte else { - // No third byte means just padding. - return Base64.encodePaddingCharacter + + @usableFromInline + static func encode(alphabet: [UInt8], thirdByte: UInt8?) -> UInt8 { + guard let thirdByte = thirdByte else { + // No third byte means just padding. + return Base64.encodePaddingCharacter + } + let index = thirdByte & 0b00111111 + return alphabet[Int(index)] } - let index = thirdByte & 0b00111111 - return alphabet[Int(index)] - } } extension String { - /// This is a backport of a proposed String initializer that will allow writing directly into an uninitialized String's backing memory. - /// - /// As this API does not exist prior to 5.3 on Linux, or on older Apple platforms, we fake it out with a pointer and accept the extra copy. - @inlinable - init(backportUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - // The buffer will store zero terminated C string - let buffer = UnsafeMutableBufferPointer.allocate(capacity: capacity + 1) - defer { - buffer.deallocate() + /// This is a backport of a proposed String initializer that will allow writing directly into an uninitialized String's backing memory. + /// + /// As this API does not exist prior to 5.3 on Linux, or on older Apple platforms, we fake it out with a pointer and accept the extra copy. + @inlinable + init( + backportUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { + // The buffer will store zero terminated C string + let buffer = UnsafeMutableBufferPointer.allocate(capacity: capacity + 1) + defer { + buffer.deallocate() + } + + let initializedCount = try initializer(buffer) + precondition(initializedCount <= capacity, "Overran buffer in initializer!") + // add zero termination + buffer[initializedCount] = 0 + + self = String(cString: buffer.baseAddress!) } - - let initializedCount = try initializer(buffer) - precondition(initializedCount <= capacity, "Overran buffer in initializer!") - // add zero termination - buffer[initializedCount] = 0 - - self = String(cString: buffer.baseAddress!) - } } extension String { - @inlinable - init(customUnsafeUninitializedCapacity capacity: Int, - initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int) rethrows { - if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) { - try self.init(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) - } else { - try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + @inlinable + init( + customUnsafeUninitializedCapacity capacity: Int, + initializingUTF8With initializer: (_ buffer: UnsafeMutableBufferPointer) throws -> Int + ) rethrows { + if #available(macOS 11.0, iOS 14.0, tvOS 14.0, watchOS 7.0, *) { + try self.init(unsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + } else { + try self.init(backportUnsafeUninitializedCapacity: capacity, initializingUTF8With: initializer) + } } - } } diff --git a/Sources/AsyncHTTPClient/BasicAuth.swift b/Sources/AsyncHTTPClient/BasicAuth.swift new file mode 100644 index 000000000..3e69f8277 --- /dev/null +++ b/Sources/AsyncHTTPClient/BasicAuth.swift @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2024 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Foundation +import NIOHTTP1 + +/// Generates base64 encoded username + password for http basic auth. +/// +/// - Parameters: +/// - username: the username to authenticate with +/// - password: authentication password associated with the username +/// - Returns: encoded credentials to use the Authorization: Basic http header. +func encodeBasicAuthCredentials(username: String, password: String) -> String { + var value = Data() + value.reserveCapacity(username.utf8.count + password.utf8.count + 1) + value.append(contentsOf: username.utf8) + value.append(UInt8(ascii: ":")) + value.append(contentsOf: password.utf8) + return value.base64EncodedString() +} + +extension HTTPHeaders { + /// Sets the basic auth header + mutating func setBasicAuth(username: String, password: String) { + let encoded = encodeBasicAuthCredentials(username: username, password: password) + self.replaceOrAdd(name: "Authorization", value: "Basic \(encoded)") + } +} diff --git a/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift b/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift index 58169f645..aca0ce235 100644 --- a/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift +++ b/Sources/AsyncHTTPClient/BestEffortHashableTLSConfiguration.swift @@ -27,6 +27,6 @@ struct BestEffortHashableTLSConfiguration: Hashable { } static func == (lhs: BestEffortHashableTLSConfiguration, rhs: BestEffortHashableTLSConfiguration) -> Bool { - return lhs.base.bestEffortEquals(rhs.base) + lhs.base.bestEffortEquals(rhs.base) } } diff --git a/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift b/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift index 7af13514c..5a0abdfad 100644 --- a/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift +++ b/Sources/AsyncHTTPClient/Configuration+BrowserLike.swift @@ -11,7 +11,11 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// +import NIOCore +import NIOHTTPCompression +import NIOSSL +// swift-format-ignore: DontRepeatTypeInStaticProperties extension HTTPClient.Configuration { /// The ``HTTPClient/Configuration`` for ``HTTPClient/shared`` which tries to mimic the platform's default or prevalent browser as closely as possible. /// @@ -27,14 +31,14 @@ extension HTTPClient.Configuration { /// - Linux (non-Android): Google Chrome public static var singletonConfiguration: HTTPClient.Configuration { // To start with, let's go with these values. Obtained from Firefox's config. - return HTTPClient.Configuration( + HTTPClient.Configuration( certificateVerification: .fullVerification, redirectConfiguration: .follow(max: 20, allowCycles: false), timeout: Timeout(connect: .seconds(90), read: .seconds(90)), connectionPool: .seconds(600), proxy: nil, ignoreUncleanSSLShutdown: false, - decompression: .enabled(limit: .ratio(10)), + decompression: .enabled(limit: .ratio(25)), backgroundActivityLogger: nil ) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool.swift index 8cca70750..b5b058c2e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool.swift @@ -12,13 +12,17 @@ // //===----------------------------------------------------------------------===// +import CNIOLinux +import NIOCore import NIOSSL #if canImport(Darwin) import Darwin.C #elseif canImport(Musl) import Musl -#elseif os(Linux) || os(FreeBSD) || os(Android) +#elseif canImport(Android) +import Android +#elseif os(Linux) || os(FreeBSD) import Glibc #else #error("unsupported target operating system") @@ -29,8 +33,7 @@ extension String { var ipv4Address = in_addr() var ipv6Address = in6_addr() return self.withCString { host in - inet_pton(AF_INET, host, &ipv4Address) == 1 || - inet_pton(AF_INET6, host, &ipv6Address) == 1 + inet_pton(AF_INET, host, &ipv4Address) == 1 || inet_pton(AF_INET6, host, &ipv6Address) == 1 } } } @@ -67,12 +70,13 @@ enum ConnectionPool { switch self.connectionTarget { case .ipAddress(let serialization, let addr): hostDescription = "\(serialization):\(addr.port!)" - case .domain(let domain, port: let port): + case .domain(let domain, let port): hostDescription = "\(domain):\(port)" case .unixSocket(let socketPath): hostDescription = socketPath } - return "\(self.scheme)://\(hostDescription)\(self.serverNameIndicatorOverride.map { " SNI: \($0)" } ?? "") TLS-hash: \(hash) " + return + "\(self.scheme)://\(hostDescription)\(self.serverNameIndicatorOverride.map { " SNI: \($0)" } ?? "") TLS-hash: \(hash) " } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift index fbcd4f9c0..1636fe379 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/HTTP1ProxyConnectHandler.swift @@ -42,7 +42,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand private var proxyEstablishedPromise: EventLoopPromise? var proxyEstablishedFuture: EventLoopFuture? { - return self.proxyEstablishedPromise?.futureResult + self.proxyEstablishedPromise?.futureResult } convenience init( @@ -53,10 +53,10 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand let targetHost: String let targetPort: Int switch target { - case .ipAddress(serialization: let serialization, address: let address): + case .ipAddress(let serialization, let address): targetHost = serialization targetPort = address.port! - case .domain(name: let domain, port: let port): + case .domain(name: let domain, let port): targetHost = domain targetPort = port case .unixSocket: @@ -70,10 +70,12 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand ) } - init(targetHost: String, - targetPort: Int, - proxyAuthorization: HTTPClient.Authorization?, - deadline: NIODeadline) { + init( + targetHost: String, + targetPort: Int, + proxyAuthorization: HTTPClient.Authorization?, + deadline: NIODeadline + ) { self.targetHost = targetHost self.targetPort = targetPort self.proxyAuthorization = proxyAuthorization @@ -135,7 +137,7 @@ final class HTTP1ProxyConnectHandler: ChannelDuplexHandler, RemovableChannelHand return } - let timeout = context.eventLoop.scheduleTask(deadline: self.deadline) { + let timeout = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.deadline) { switch self.state { case .initialized: preconditionFailure("How can we have a scheduled timeout, if the connection is not even up?") diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift index 5a46f44a7..7458627fd 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/SOCKSEventsHandler.swift @@ -31,7 +31,7 @@ final class SOCKSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { private var socksEstablishedPromise: EventLoopPromise? var socksEstablishedFuture: EventLoopFuture? { - return self.socksEstablishedPromise?.futureResult + self.socksEstablishedPromise?.futureResult } private let deadline: NIODeadline @@ -99,7 +99,7 @@ final class SOCKSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { return } - let scheduled = context.eventLoop.scheduleTask(deadline: self.deadline) { + let scheduled = context.eventLoop.assumeIsolated().scheduleTask(deadline: self.deadline) { switch self.state { case .initialized, .channelActive: // close the connection, if the handshake timed out diff --git a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift index aab26fda8..d210b2747 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/ChannelHandler/TLSEventsHandler.swift @@ -31,7 +31,7 @@ final class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { private var tlsEstablishedPromise: EventLoopPromise? var tlsEstablishedFuture: EventLoopFuture? { - return self.tlsEstablishedPromise?.futureResult + self.tlsEstablishedPromise?.futureResult } private let deadline: NIODeadline? @@ -104,7 +104,7 @@ final class TLSEventsHandler: ChannelInboundHandler, RemovableChannelHandler { var scheduled: Scheduled? if let deadline = deadline { - scheduled = context.eventLoop.scheduleTask(deadline: deadline) { + scheduled = context.eventLoop.assumeIsolated().scheduleTask(deadline: deadline) { switch self.state { case .initialized, .channelActive: // close the connection, if the handshake timed out diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift index 04de8b352..191517c71 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ClientChannelHandler.swift @@ -36,7 +36,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { if let newRequest = self.request { var requestLogger = newRequest.logger requestLogger[metadataKey: "ahc-connection-id"] = self.connectionIdLoggerMetadata - requestLogger[metadataKey: "ahc-el"] = "\(self.eventLoop)" + requestLogger[metadataKey: "ahc-el"] = self.eventLoopDescription self.logger = requestLogger if let idleReadTimeout = newRequest.requestOptions.idleReadTimeout { @@ -72,11 +72,13 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { private let backgroundLogger: Logger private var logger: Logger private let eventLoop: EventLoop + private let eventLoopDescription: Logger.MetadataValue private let connectionIdLoggerMetadata: Logger.MetadataValue var onConnectionIdle: () -> Void = {} init(eventLoop: EventLoop, backgroundLogger: Logger, connectionIdLoggerMetadata: Logger.MetadataValue) { self.eventLoop = eventLoop + self.eventLoopDescription = "\(eventLoop.description)" self.backgroundLogger = backgroundLogger self.logger = backgroundLogger self.connectionIdLoggerMetadata = connectionIdLoggerMetadata @@ -98,9 +100,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // MARK: Channel Inbound Handler func channelActive(context: ChannelHandlerContext) { - self.logger.trace("Channel active", metadata: [ - "ahc-channel-writable": "\(context.channel.isWritable)", - ]) + self.logger.trace( + "Channel active", + metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)" + ] + ) let action = self.state.channelActive(isWritable: context.channel.isWritable) self.run(action, context: context) @@ -114,9 +119,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } func channelWritabilityChanged(context: ChannelHandlerContext) { - self.logger.trace("Channel writability changed", metadata: [ - "ahc-channel-writable": "\(context.channel.isWritable)", - ]) + self.logger.trace( + "Channel writability changed", + metadata: [ + "ahc-channel-writable": "\(context.channel.isWritable)" + ] + ) if let timeoutAction = self.idleWriteTimeoutStateMachine?.channelWritabilityChanged(context: context) { self.runTimeoutAction(timeoutAction, context: context) @@ -130,9 +138,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { func channelRead(context: ChannelHandlerContext, data: NIOAny) { let httpPart = self.unwrapInboundIn(data) - self.logger.trace("HTTP response part received", metadata: [ - "ahc-http-part": "\(httpPart)", - ]) + self.logger.trace( + "HTTP response part received", + metadata: [ + "ahc-http-part": "\(httpPart)" + ] + ) if let timeoutAction = self.idleReadTimeoutStateMachine?.channelRead(httpPart) { self.runTimeoutAction(timeoutAction, context: context) @@ -150,9 +161,12 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { } func errorCaught(context: ChannelHandlerContext, error: Error) { - self.logger.trace("Channel error caught", metadata: [ - "ahc-error": "\(error)", - ]) + self.logger.trace( + "Channel error caught", + metadata: [ + "ahc-error": "\(error)" + ] + ) let action = self.state.errorHappened(error) self.run(action, context: context) @@ -171,7 +185,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.runTimeoutAction(timeoutAction, context: context) } - req.willExecuteRequest(self) + req.willExecuteRequest(self.requestExecutor) let action = self.state.runNewRequest( head: req.requestHead, @@ -300,6 +314,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { case .close: @@ -308,7 +323,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { case .sendRequestEnd(let writePromise, let shouldClose): let writePromise = writePromise ?? context.eventLoop.makePromise(of: Void.self) // We need to defer succeeding the old request to avoid ordering issues - writePromise.futureResult.hop(to: context.eventLoop).whenComplete { result in + writePromise.futureResult.hop(to: context.eventLoop).assumeIsolated().whenComplete { result in switch result { case .success: // If our final action was `sendRequestEnd`, that means we've already received @@ -339,6 +354,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { let oldRequest = self.request! self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) switch finalAction { case .close(let writePromise): @@ -380,7 +396,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -393,7 +409,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.currentIdleReadTimeoutTimerID &+= 1 let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -415,7 +431,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -427,7 +443,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.currentIdleWriteTimeoutTimerID &+= 1 let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -445,7 +461,11 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + fileprivate func writeRequestBodyPart0( + _ data: IOData, + request: HTTPExecutableRequest, + promise: EventLoopPromise? + ) { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after @@ -464,7 +484,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + fileprivate func finishRequestBodyStream0(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` promise?.fail(HTTPClientError.requestStreamCancelled) @@ -475,7 +495,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { + fileprivate func demandResponseBodyStream0(_ request: HTTPExecutableRequest) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return @@ -487,7 +507,7 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { self.run(action, context: context) } - private func cancelRequest0(_ request: HTTPExecutableRequest) { + fileprivate func cancelRequest0(_ request: HTTPExecutableRequest) { guard self.request === request, let context = self.channelContext else { // See code comment in `writeRequestBodyPart0` return @@ -507,43 +527,39 @@ final class HTTP1ClientChannelHandler: ChannelDuplexHandler { @available(*, unavailable) extension HTTP1ClientChannelHandler: Sendable {} -extension HTTP1ClientChannelHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request, promise: promise) - } else { - self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request, promise: promise) +extension HTTP1ClientChannelHandler { + var requestExecutor: RequestExecutor { + RequestExecutor(self) + } + + struct RequestExecutor: HTTPRequestExecutor, Sendable { + private let loopBound: NIOLoopBound + + init(_ handler: HTTP1ClientChannelHandler) { + self.loopBound = NIOLoopBound(handler, eventLoop: handler.eventLoop) + } + + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.writeRequestBodyPart0(data, request: request, promise: promise) } } - } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request, promise: promise) - } else { - self.eventLoop.execute { - self.finishRequestBodyStream0(request, promise: promise) + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.finishRequestBodyStream0(request, promise: promise) } } - } - func demandResponseBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.demandResponseBodyStream0(request) - } else { - self.eventLoop.execute { - self.demandResponseBodyStream0(request) + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.demandResponseBodyStream0(request) } } - } - func cancelRequest(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.cancelRequest0(request) - } else { - self.eventLoop.execute { - self.cancelRequest0(request) + func cancelRequest(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.cancelRequest0(request) } } } @@ -665,7 +681,8 @@ struct IdleWriteStateMachine { self.state = .requestEndSent return .clearIdleWriteTimeoutTimer case .waitingForWritabilityEnabled: - preconditionFailure("If the channel is not writable, we can't have sent the request end.") + self.state = .requestEndSent + return .none case .requestEndSent: return .none } @@ -688,7 +705,9 @@ struct IdleWriteStateMachine { self.state = .waitingForWritabilityEnabled return .clearIdleWriteTimeoutTimer case .waitingForWritabilityEnabled: - preconditionFailure("If the channel was writable before, then we should have been waiting for more data.") + preconditionFailure( + "If the channel was writable before, then we should have been waiting for more data." + ) case .requestEndSent: return .none } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift index ee0a78498..6f64e0407 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1Connection.swift @@ -17,9 +17,9 @@ import NIOCore import NIOHTTP1 import NIOHTTPCompression -protocol HTTP1ConnectionDelegate { - func http1ConnectionReleased(_: HTTP1Connection) - func http1ConnectionClosed(_: HTTP1Connection) +protocol HTTP1ConnectionDelegate: Sendable { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) } final class HTTP1Connection { @@ -39,9 +39,11 @@ final class HTTP1Connection { let id: HTTPConnectionPool.Connection.ID - init(channel: Channel, - connectionID: HTTPConnectionPool.Connection.ID, - delegate: HTTP1ConnectionDelegate) { + init( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + delegate: HTTP1ConnectionDelegate + ) { self.channel = channel self.id = connectionID self.delegate = delegate @@ -65,32 +67,45 @@ final class HTTP1Connection { return connection } - func executeRequest(_ request: HTTPExecutableRequest) { - if self.channel.eventLoop.inEventLoop { - self.execute0(request: request) - } else { - self.channel.eventLoop.execute { - self.execute0(request: request) + var sendableView: SendableView { + SendableView(self) + } + + struct SendableView: Sendable { + private let connection: NIOLoopBound + let channel: Channel + let id: HTTPConnectionPool.Connection.ID + private var eventLoop: EventLoop { self.connection.eventLoop } + + init(_ connection: HTTP1Connection) { + self.connection = NIOLoopBound(connection, eventLoop: connection.channel.eventLoop) + self.id = connection.id + self.channel = connection.channel + } + + func executeRequest(_ request: HTTPExecutableRequest) { + self.connection.execute { + $0.execute0(request: request) } } - } - func shutdown() { - self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) - } + func shutdown() { + self.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) + } - func close(promise: EventLoopPromise?) { - return self.channel.close(mode: .all, promise: promise) - } + func close(promise: EventLoopPromise?) { + self.channel.close(mode: .all, promise: promise) + } - func close() -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: Void.self) - self.close(promise: promise) - return promise.futureResult + func close() -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(promise: promise) + return promise.futureResult + } } func taskCompleted() { - self.delegate.http1ConnectionReleased(self) + self.delegate.http1ConnectionReleased(self.id) } private func execute0(request: HTTPExecutableRequest) { @@ -98,7 +113,7 @@ final class HTTP1Connection { return request.fail(ChannelError.ioOnClosedChannel) } - self.channel.write(request, promise: nil) + self.channel.pipeline.syncOperations.write(NIOAny(request), promise: nil) } private func start(decompression: HTTPClient.Decompression, logger: Logger) throws { @@ -109,9 +124,9 @@ final class HTTP1Connection { } self.state = .active - self.channel.closeFuture.whenComplete { _ in + self.channel.closeFuture.assumeIsolated().whenComplete { _ in self.state = .closed - self.delegate.http1ConnectionClosed(self) + self.delegate.http1ConnectionClosed(self.id) } do { @@ -148,3 +163,6 @@ final class HTTP1Connection { } } } + +@available(*, unavailable) +extension HTTP1Connection: Sendable {} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift index ed4594183..2cde1df3f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP1/HTTP1ConnectionStateMachine.swift @@ -140,7 +140,7 @@ struct HTTP1ConnectionStateMachine { self.state = .closed return .fireChannelError(error, closeConnection: false) - case .inRequest(var requestStateMachine, close: let close): + case .inRequest(var requestStateMachine, let close): return self.avoidingStateMachineCoW { state -> Action in let action = requestStateMachine.errorHappened(error) state = .inRequest(requestStateMachine, close: close) @@ -239,7 +239,9 @@ struct HTTP1ConnectionStateMachine { mutating func requestCancelled(closeConnection: Bool) -> Action { switch self.state { case .initialized: - fatalError("This event must only happen, if the connection is leased. During startup this is impossible. Invalid state: \(self.state)") + fatalError( + "This event must only happen, if the connection is leased. During startup this is impossible. Invalid state: \(self.state)" + ) case .idle: if closeConnection { @@ -249,7 +251,7 @@ struct HTTP1ConnectionStateMachine { return .wait } - case .inRequest(var requestStateMachine, close: let close): + case .inRequest(var requestStateMachine, let close): return self.avoidingStateMachineCoW { state -> Action in let action = requestStateMachine.requestCancelled() state = .inRequest(requestStateMachine, close: close || closeConnection) @@ -357,7 +359,7 @@ struct HTTP1ConnectionStateMachine { mutating func idleWriteTimeoutTriggered() -> Action { guard case .inRequest(var requestStateMachine, let close) = self.state else { - preconditionFailure("Invalid state: \(self.state)") + return .wait } return self.avoidingStateMachineCoW { state -> Action in @@ -415,12 +417,16 @@ extension HTTP1ConnectionStateMachine { } extension HTTP1ConnectionStateMachine.State { - fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action { + fileprivate mutating func modify(with action: HTTPRequestStateMachine.Action) -> HTTP1ConnectionStateMachine.Action + { switch action { case .sendRequestHead(let head, let sendEnd): return .sendRequestHead(head, sendEnd: sendEnd) case .notifyRequestHeadSendSuccessfully(let resumeRequestBodyStream, let startIdleTimer): - return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: resumeRequestBodyStream, startIdleTimer: startIdleTimer) + return .notifyRequestHeadSendSuccessfully( + resumeRequestBodyStream: resumeRequestBodyStream, + startIdleTimer: startIdleTimer + ) case .pauseRequestBodyStream: return .pauseRequestBodyStream case .resumeRequestBodyStream: @@ -458,7 +464,7 @@ extension HTTP1ConnectionStateMachine.State { fatalError("Invalid state: \(self)") case .idle: fatalError("How can we fail a task, if we are idle") - case .inRequest(_, close: let close): + case .inRequest(_, let close): if case .close(let promise) = finalAction { self = .closing return .failRequest(error, .close(promise)) @@ -502,7 +508,7 @@ extension HTTP1ConnectionStateMachine: CustomStringConvertible { return ".initialized" case .idle: return ".idle" - case .inRequest(let request, close: let close): + case .inRequest(let request, let close): return ".inRequest(\(request), closeAfterRequest: \(close))" case .closing: return ".closing" diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift index 1520ff414..7c0197cdf 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2ClientRequestHandler.swift @@ -68,8 +68,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } func handlerAdded(context: ChannelHandlerContext) { - assert(context.eventLoop === self.eventLoop, - "The handler must be added to a channel that runs on the eventLoop it was initialized with.") + assert( + context.eventLoop === self.eventLoop, + "The handler must be added to a channel that runs on the eventLoop it was initialized with." + ) self.channelContext = context let isWritable = context.channel.isActive && context.channel.isWritable @@ -135,7 +137,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.runTimeoutAction(timeoutAction, context: context) } - request.willExecuteRequest(self) + request.willExecuteRequest(self.requestExecutor) let action = self.state.startRequest( head: request.requestHead, @@ -216,7 +218,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // that the request is neither failed nor finished yet self.request!.resumeRequestBodyStream() - case .forwardResponseHead(let head, pauseRequestBodyStream: let pauseRequestBodyStream): + case .forwardResponseHead(let head, let pauseRequestBodyStream): // We can force unwrap the request here, as we have just validated in the state machine, // that the request is neither failed nor finished yet self.request!.receiveResponseHead(head) @@ -238,6 +240,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.request!.fail(error) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) // No matter the error reason, we must always make sure the h2 stream is closed. Only // once the h2 stream is closed, it is released from the h2 multiplexer. The // HTTPRequestStateMachine may signal finalAction: .none in the error case (as this is @@ -250,6 +253,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.request!.succeedRequest(finalParts) self.request = nil self.runTimeoutAction(.clearIdleReadTimeoutTimer, context: context) + self.runTimeoutAction(.clearIdleWriteTimeoutTimer, context: context) self.runSuccessfulFinalAction(finalAction, context: context) case .failSendBodyPart(let error, let writePromise), .failSendStreamFinished(let error, let writePromise): @@ -268,7 +272,10 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.run(self.state.headSent(), context: context) } - private func runSuccessfulFinalAction(_ action: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, context: ChannelHandlerContext) { + private func runSuccessfulFinalAction( + _ action: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, + context: ChannelHandlerContext + ) { switch action { case .close, .none: // The actions returned here come from an `HTTPRequestStateMachine` that assumes http/1.1 @@ -281,7 +288,11 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } } - private func runFailedFinalAction(_ action: HTTPRequestStateMachine.Action.FinalFailedRequestAction, context: ChannelHandlerContext, error: Error) { + private func runFailedFinalAction( + _ action: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + context: ChannelHandlerContext, + error: Error + ) { // We must close the http2 stream after the request has finished. Since the request failed, // we have no idea what the h2 streams state was. To be on the save side, we explicitly close // the h2 stream. This will break a reference cycle in HTTP2Connection. @@ -302,7 +313,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { assert(self.idleReadTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -315,7 +326,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.currentIdleReadTimeoutTimerID &+= 1 let timerID = self.currentIdleReadTimeoutTimerID - self.idleReadTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleReadTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleReadTimeoutTimerID == timerID else { return } let action = self.state.idleReadTimeoutTriggered() self.run(action, context: context) @@ -338,7 +349,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { assert(self.idleWriteTimeoutTimer == nil, "Expected there is no timeout timer so far.") let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -350,7 +361,7 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { self.currentIdleWriteTimeoutTimerID &+= 1 let timerID = self.currentIdleWriteTimeoutTimerID - self.idleWriteTimeoutTimer = self.eventLoop.scheduleTask(in: timeAmount) { + self.idleWriteTimeoutTimer = self.eventLoop.assumeIsolated().scheduleTask(in: timeAmount) { guard self.currentIdleWriteTimeoutTimerID == timerID else { return } let action = self.state.idleWriteTimeoutTriggered() self.run(action, context: context) @@ -368,7 +379,8 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { // MARK: Private HTTPRequestExecutor - private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + private func writeRequestBodyPart0(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) + { guard self.request === request, let context = self.channelContext else { // Because the HTTPExecutableRequest may run in a different thread to our eventLoop, // calls from the HTTPExecutableRequest to our ChannelHandler may arrive here after @@ -422,43 +434,42 @@ final class HTTP2ClientRequestHandler: ChannelDuplexHandler { } } -extension HTTP2ClientRequestHandler: HTTPRequestExecutor { - func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.writeRequestBodyPart0(data, request: request, promise: promise) - } else { - self.eventLoop.execute { - self.writeRequestBodyPart0(data, request: request, promise: promise) +@available(*, unavailable) +extension HTTP2ClientRequestHandler: Sendable {} + +extension HTTP2ClientRequestHandler { + var requestExecutor: RequestExecutor { + RequestExecutor(self) + } + + struct RequestExecutor: HTTPRequestExecutor, Sendable { + private let loopBound: NIOLoopBound + + init(_ handler: HTTP2ClientRequestHandler) { + self.loopBound = NIOLoopBound(handler, eventLoop: handler.eventLoop) + } + + func writeRequestBodyPart(_ data: IOData, request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.writeRequestBodyPart0(data, request: request, promise: promise) } } - } - func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { - if self.eventLoop.inEventLoop { - self.finishRequestBodyStream0(request, promise: promise) - } else { - self.eventLoop.execute { - self.finishRequestBodyStream0(request, promise: promise) + func finishRequestBodyStream(_ request: HTTPExecutableRequest, promise: EventLoopPromise?) { + self.loopBound.execute { + $0.finishRequestBodyStream0(request, promise: promise) } } - } - func demandResponseBodyStream(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.demandResponseBodyStream0(request) - } else { - self.eventLoop.execute { - self.demandResponseBodyStream0(request) + func demandResponseBodyStream(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.demandResponseBodyStream0(request) } } - } - func cancelRequest(_ request: HTTPExecutableRequest) { - if self.eventLoop.inEventLoop { - self.cancelRequest0(request) - } else { - self.eventLoop.execute { - self.cancelRequest0(request) + func cancelRequest(_ request: HTTPExecutableRequest) { + self.loopBound.execute { + $0.cancelRequest0(request) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift index 2c3c3cc0a..1c24554e2 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2Connection.swift @@ -17,11 +17,11 @@ import NIOCore import NIOHTTP2 import NIOHTTPCompression -protocol HTTP2ConnectionDelegate { - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) - func http2ConnectionGoAwayReceived(_: HTTP2Connection) - func http2ConnectionClosed(_: HTTP2Connection) +protocol HTTP2ConnectionDelegate: Sendable { + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) } struct HTTP2PushNotSupportedError: Error {} @@ -29,10 +29,15 @@ struct HTTP2PushNotSupportedError: Error {} struct HTTP2ReceivedGoAwayBeforeSettingsError: Error {} final class HTTP2Connection { + internal static let defaultSettings = nioDefaultSettings + [HTTP2Setting(parameter: .enablePush, value: 0)] + let channel: Channel let multiplexer: HTTP2StreamMultiplexer let logger: Logger + /// A method with access to the stream channel that is called when creating the stream. + let streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + /// the connection pool that created the connection let delegate: HTTP2ConnectionDelegate @@ -87,12 +92,15 @@ final class HTTP2Connection { self.channel.closeFuture } - init(channel: Channel, - connectionID: HTTPConnectionPool.Connection.ID, - decompression: HTTPClient.Decompression, - maximumConnectionUses: Int?, - delegate: HTTP2ConnectionDelegate, - logger: Logger) { + init( + channel: Channel, + connectionID: HTTPConnectionPool.Connection.ID, + decompression: HTTPClient.Decompression, + maximumConnectionUses: Int?, + delegate: HTTP2ConnectionDelegate, + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) { self.channel = channel self.id = connectionID self.decompression = decompression @@ -101,7 +109,7 @@ final class HTTP2Connection { self.multiplexer = HTTP2StreamMultiplexer( mode: .client, channel: channel, - targetWindowSize: 8 * 1024 * 1024, // 8mb + targetWindowSize: 8 * 1024 * 1024, // 8mb outboundBufferSizeHighWatermark: 8196, outboundBufferSizeLowWatermark: 4092, inboundStreamInitializer: { channel -> EventLoopFuture in @@ -110,6 +118,7 @@ final class HTTP2Connection { ) self.delegate = delegate self.state = .initialized + self.streamChannelDebugInitializer = streamChannelDebugInitializer } deinit { @@ -124,49 +133,72 @@ final class HTTP2Connection { delegate: HTTP2ConnectionDelegate, decompression: HTTPClient.Decompression, maximumConnectionUses: Int?, - logger: Logger - ) -> EventLoopFuture<(HTTP2Connection, Int)> { + logger: Logger, + streamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) -> EventLoopFuture<(HTTP2Connection, Int)>.Isolated { let connection = HTTP2Connection( channel: channel, connectionID: connectionID, decompression: decompression, maximumConnectionUses: maximumConnectionUses, delegate: delegate, - logger: logger + logger: logger, + streamChannelDebugInitializer: streamChannelDebugInitializer ) - return connection._start0().map { maxStreams in (connection, maxStreams) } + + return connection._start0().assumeIsolated().map { maxStreams in + (connection, maxStreams) + } + } + + var sendableView: SendableView { + SendableView(self) } - func executeRequest(_ request: HTTPExecutableRequest) { - if self.channel.eventLoop.inEventLoop { - self.executeRequest0(request) - } else { - self.channel.eventLoop.execute { - self.executeRequest0(request) + struct SendableView: Sendable { + private let connection: NIOLoopBound + let id: HTTPConnectionPool.Connection.ID + let channel: Channel + + var eventLoop: EventLoop { + self.connection.eventLoop + } + + var closeFuture: EventLoopFuture { + self.channel.closeFuture + } + + func __forTesting_getStreamChannels() -> [Channel] { + self.connection.value.__forTesting_getStreamChannels() + } + + init(_ connection: HTTP2Connection) { + self.connection = NIOLoopBound(connection, eventLoop: connection.channel.eventLoop) + self.id = connection.id + self.channel = connection.channel + } + + func executeRequest(_ request: HTTPExecutableRequest) { + self.connection.execute { + $0.executeRequest0(request) } } - } - /// shuts down the connection by cancelling all running tasks and closing the connection once - /// all child streams/channels are closed. - func shutdown() { - if self.channel.eventLoop.inEventLoop { - self.shutdown0() - } else { - self.channel.eventLoop.execute { - self.shutdown0() + func shutdown() { + self.connection.execute { + $0.shutdown0() } } - } - func close(promise: EventLoopPromise?) { - return self.channel.close(mode: .all, promise: promise) - } + func close(promise: EventLoopPromise?) { + self.channel.close(mode: .all, promise: promise) + } - func close() -> EventLoopFuture { - let promise = self.channel.eventLoop.makePromise(of: Void.self) - self.close(promise: promise) - return promise.futureResult + func close() -> EventLoopFuture { + let promise = self.eventLoop.makePromise(of: Void.self) + self.close(promise: promise) + return promise.futureResult + } } func _start0() -> EventLoopFuture { @@ -175,7 +207,7 @@ final class HTTP2Connection { let readyToAcceptConnectionsPromise = self.channel.eventLoop.makePromise(of: Int.self) self.state = .starting(readyToAcceptConnectionsPromise) - self.channel.closeFuture.whenComplete { _ in + self.channel.closeFuture.assumeIsolated().whenComplete { _ in switch self.state { case .initialized, .closed: preconditionFailure("invalid state \(self.state)") @@ -184,7 +216,7 @@ final class HTTP2Connection { readyToAcceptConnectionsPromise.fail(HTTPClientError.remoteConnectionClosed) case .active, .closing: self.state = .closed - self.delegate.http2ConnectionClosed(self) + self.delegate.http2ConnectionClosed(self.id) } } @@ -196,8 +228,12 @@ final class HTTP2Connection { // can be scheduled on this connection. let sync = self.channel.pipeline.syncOperations - let http2Handler = NIOHTTP2Handler(mode: .client, initialSettings: nioDefaultSettings) - let idleHandler = HTTP2IdleHandler(delegate: self, logger: self.logger, maximumConnectionUses: self.maximumConnectionUses) + let http2Handler = NIOHTTP2Handler(mode: .client, initialSettings: Self.defaultSettings) + let idleHandler = HTTP2IdleHandler( + delegate: self, + logger: self.logger, + maximumConnectionUses: self.maximumConnectionUses + ) try sync.addHandler(http2Handler, position: .last) try sync.addHandler(idleHandler, position: .last) @@ -219,12 +255,18 @@ final class HTTP2Connection { case .active: let createStreamChannelPromise = self.channel.eventLoop.makePromise(of: Channel.self) - self.multiplexer.createStreamChannel(promise: createStreamChannelPromise) { channel -> EventLoopFuture in + let loopBoundSelf = NIOLoopBound(self, eventLoop: self.channel.eventLoop) + + self.multiplexer.createStreamChannel( + promise: createStreamChannelPromise + ) { [streamChannelDebugInitializer] channel -> EventLoopFuture in + let connection = loopBoundSelf.value + do { // the connection may have been asked to shutdown while we created the child. in // this // channel. - guard case .active = self.state else { + guard case .active = connection.state else { throw HTTPClientError.cancelled } @@ -233,7 +275,7 @@ final class HTTP2Connection { let translate = HTTP2FramePayloadToHTTP1ClientCodec(httpProtocol: .https) try channel.pipeline.syncOperations.addHandler(translate) - if case .enabled(let limit) = self.decompression { + if case .enabled(let limit) = connection.decompression { let decompressHandler = NIOHTTPResponseDecompressor(limit: limit) try channel.pipeline.syncOperations.addHandler(decompressHandler) } @@ -245,13 +287,19 @@ final class HTTP2Connection { // request to it. In case of an error, we are sure that the channel was added // before. let box = ChannelBox(channel) - self.openStreams.insert(box) - channel.closeFuture.whenComplete { _ in - self.openStreams.remove(box) + connection.openStreams.insert(box) + channel.closeFuture.assumeIsolated().whenComplete { _ in + connection.openStreams.remove(box) } - channel.write(request, promise: nil) - return channel.eventLoop.makeSucceededVoidFuture() + if let streamChannelDebugInitializer = streamChannelDebugInitializer { + return streamChannelDebugInitializer(channel).map { _ in + channel.write(request, promise: nil) + } + } else { + channel.pipeline.syncOperations.write(NIOAny(request), promise: nil) + return channel.eventLoop.makeSucceededVoidFuture() + } } catch { return channel.eventLoop.makeFailedFuture(error) } @@ -276,7 +324,7 @@ final class HTTP2Connection { self.state = .closing // inform all open streams, that the currently running request should be cancelled. - self.openStreams.forEach { box in + for box in self.openStreams { box.channel.triggerUserOutboundEvent(HTTPConnectionEvent.shutdownRequested, promise: nil) } @@ -313,7 +361,7 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { case .active: self.state = .active(maxStreams: maxStreams) - self.delegate.http2Connection(self, newMaxStreamSetting: maxStreams) + self.delegate.http2Connection(self.id, newMaxStreamSetting: maxStreams) case .closing, .closed: // ignore. we only wait for all connections to be closed anyway. @@ -334,7 +382,7 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { case .active: self.state = .closing - self.delegate.http2ConnectionGoAwayReceived(self) + self.delegate.http2ConnectionGoAwayReceived(self.id) case .closing, .closed: // we are already closing. Nothing new @@ -345,6 +393,9 @@ extension HTTP2Connection: HTTP2IdleHandlerDelegate { func http2StreamClosed(availableStreams: Int) { self.channel.eventLoop.assertInEventLoop() - self.delegate.http2ConnectionStreamClosed(self, availableStreams: availableStreams) + self.delegate.http2ConnectionStreamClosed(self.id, availableStreams: availableStreams) } } + +@available(*, unavailable) +extension HTTP2Connection: Sendable {} diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift index 06458cb7e..64a151489 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTP2/HTTP2IdleHandler.swift @@ -184,9 +184,15 @@ extension HTTP2IdleHandler { self.state = .active(openStreams: 0, maxStreams: maxStreams, remainingUses: remainingUses) return .notifyConnectionNewMaxStreamsSettings(maxStreams) - case .active(openStreams: let openStreams, maxStreams: let maxStreams, remainingUses: let remainingUses): - if let newMaxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value, newMaxStreams != maxStreams { - self.state = .active(openStreams: openStreams, maxStreams: newMaxStreams, remainingUses: remainingUses) + case .active(let openStreams, let maxStreams, let remainingUses): + if let newMaxStreams = settings.last(where: { $0.parameter == .maxConcurrentStreams })?.value, + newMaxStreams != maxStreams + { + self.state = .active( + openStreams: openStreams, + maxStreams: newMaxStreams, + remainingUses: remainingUses + ) return .notifyConnectionNewMaxStreamsSettings(newMaxStreams) } return .nothing diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift index 1461a6620..c896791cf 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Factory.swift @@ -20,7 +20,9 @@ import NIOPosix import NIOSOCKS import NIOSSL import NIOTLS + #if canImport(Network) +import Network import NIOTransportServices #endif @@ -31,21 +33,24 @@ extension HTTPConnectionPool { let tlsConfiguration: TLSConfiguration let sslContextCache: SSLContextCache - init(key: ConnectionPool.Key, - tlsConfiguration: TLSConfiguration?, - clientConfiguration: HTTPClient.Configuration, - sslContextCache: SSLContextCache) { + init( + key: ConnectionPool.Key, + tlsConfiguration: TLSConfiguration?, + clientConfiguration: HTTPClient.Configuration, + sslContextCache: SSLContextCache + ) { self.key = key self.clientConfiguration = clientConfiguration self.sslContextCache = sslContextCache - self.tlsConfiguration = tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() + self.tlsConfiguration = + tlsConfiguration ?? clientConfiguration.tlsConfiguration ?? .makeClientConfiguration() } } } -protocol HTTPConnectionRequester { - func http1ConnectionCreated(_: HTTP1Connection) - func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) +protocol HTTPConnectionRequester: Sendable { + func http1ConnectionCreated(_: HTTP1Connection.SendableView) + func http2ConnectionCreated(_: HTTP2Connection.SendableView, maximumStreams: Int) func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Error) func waitingForConnectivity(_: HTTPConnectionPool.Connection.ID, error: Error) } @@ -63,7 +68,13 @@ extension HTTPConnectionPool.ConnectionFactory { var logger = logger logger[metadataKey: "ahc-connection-id"] = "\(connectionID)" - self.makeChannel(requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger).whenComplete { result in + self.makeChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger + ).whenComplete { [logger] result in switch result { case .success(.http1_1(let channel)): do { @@ -74,7 +85,21 @@ extension HTTPConnectionPool.ConnectionFactory { decompression: self.clientConfiguration.decompression, logger: logger ) - requester.http1ConnectionCreated(connection) + + if let connectionDebugInitializer = self.clientConfiguration.http1_1ConnectionDebugInitializer { + connectionDebugInitializer(channel).hop( + to: eventLoop + ).assumeIsolated().whenComplete { debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http1ConnectionCreated(connection.sendableView) + case .failure(let error): + requester.failedToCreateHTTPConnection(connectionID, error: error) + } + } + } else { + requester.http1ConnectionCreated(connection.sendableView) + } } catch { requester.failedToCreateHTTPConnection(connectionID, error: error) } @@ -85,11 +110,34 @@ extension HTTPConnectionPool.ConnectionFactory { delegate: http2ConnectionDelegate, decompression: self.clientConfiguration.decompression, maximumConnectionUses: self.clientConfiguration.maximumUsesPerConnection, - logger: logger + logger: logger, + streamChannelDebugInitializer: + self.clientConfiguration.http2StreamChannelDebugInitializer ).whenComplete { result in switch result { case .success((let connection, let maximumStreams)): - requester.http2ConnectionCreated(connection, maximumStreams: maximumStreams) + if let connectionDebugInitializer = self.clientConfiguration.http2ConnectionDebugInitializer { + connectionDebugInitializer(channel).hop(to: eventLoop).assumeIsolated().whenComplete { + debugInitializerResult in + switch debugInitializerResult { + case .success: + requester.http2ConnectionCreated( + connection.sendableView, + maximumStreams: maximumStreams + ) + case .failure(let error): + requester.failedToCreateHTTPConnection( + connectionID, + error: error + ) + } + } + } else { + requester.http2ConnectionCreated( + connection.sendableView, + maximumStreams: maximumStreams + ) + } case .failure(let error): requester.failedToCreateHTTPConnection(connectionID, error: error) } @@ -137,7 +185,13 @@ extension HTTPConnectionPool.ConnectionFactory { ) } } else { - channelFuture = self.makeNonProxiedChannel(requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger) + channelFuture = self.makeNonProxiedChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger + ) } // let's map `ChannelError.connectTimeout` into a `HTTPClientError.connectTimeout` @@ -160,10 +214,22 @@ extension HTTPConnectionPool.ConnectionFactory { ) -> EventLoopFuture { switch self.key.scheme { case .http, .httpUnix, .unix: - return self.makePlainChannel(requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop).map { .http1_1($0) } + return self.makePlainChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ).map { .http1_1($0) } case .https, .httpsUnix: - return self.makeTLSChannel(requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop, logger: logger).flatMapThrowing { - channel, negotiated in + return self.makeTLSChannel( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop, + logger: logger + ).flatMapThrowing { + channel, + negotiated in try self.matchALPNToHTTPVersion(negotiated, channel: channel) } @@ -177,7 +243,12 @@ extension HTTPConnectionPool.ConnectionFactory { eventLoop: EventLoop ) -> EventLoopFuture { precondition(!self.key.scheme.usesTLS, "Unexpected scheme") - return self.makePlainBootstrap(requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop).connect(target: self.key.connectionTarget) + return self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ).connect(target: self.key.connectionTarget) } private func makeHTTPProxyChannel( @@ -191,7 +262,12 @@ extension HTTPConnectionPool.ConnectionFactory { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. - let bootstrap = self.makePlainBootstrap(requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop) + let bootstrap = self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ) return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in let encoder = HTTPRequestEncoder() let decoder = ByteToMessageHandler(HTTPResponseDecoder(leftOverBytesStrategy: .dropBytes)) @@ -211,15 +287,15 @@ extension HTTPConnectionPool.ConnectionFactory { // The proxyEstablishedFuture is set as soon as the HTTP1ProxyConnectHandler is in a // pipeline. It is created in HTTP1ProxyConnectHandler's handlerAdded method. - return proxyHandler.proxyEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(proxyHandler).flatMap { - channel.pipeline.removeHandler(decoder).flatMap { - channel.pipeline.removeHandler(encoder) - } - } + return proxyHandler.proxyEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(proxyHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(decoder).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(encoder) + }.nonisolated() + }.nonisolated() }.flatMap { self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - } + }.nonisolated() } } @@ -234,7 +310,12 @@ extension HTTPConnectionPool.ConnectionFactory { // A proxy connection starts with a plain text connection to the proxy server. After // the connection has been established with the proxy server, the connection might be // upgraded to TLS before we send our first request. - let bootstrap = self.makePlainBootstrap(requester: requester, connectionID: connectionID, deadline: deadline, eventLoop: eventLoop) + let bootstrap = self.makePlainBootstrap( + requester: requester, + connectionID: connectionID, + deadline: deadline, + eventLoop: eventLoop + ) return bootstrap.connect(host: proxy.host, port: proxy.port).flatMap { channel in let socksConnectHandler = SOCKSClientHandler(targetAddress: SOCKSAddress(self.key.connectionTarget)) let socksEventHandler = SOCKSEventsHandler(deadline: deadline) @@ -248,13 +329,13 @@ extension HTTPConnectionPool.ConnectionFactory { // The socksEstablishedFuture is set as soon as the SOCKSEventsHandler is in a // pipeline. It is created in SOCKSEventsHandler's handlerAdded method. - return socksEventHandler.socksEstablishedFuture!.flatMap { - channel.pipeline.removeHandler(socksEventHandler).flatMap { - channel.pipeline.removeHandler(socksConnectHandler) - } + return socksEventHandler.socksEstablishedFuture!.assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksEventHandler).assumeIsolated().flatMap { + channel.pipeline.syncOperations.removeHandler(socksConnectHandler) + }.nonisolated() }.flatMap { self.setupTLSInProxyConnectionIfNeeded(channel, deadline: deadline, logger: logger) - } + }.nonisolated() } } @@ -280,7 +361,6 @@ extension HTTPConnectionPool.ConnectionFactory { case .http1Only: tlsConfig.applicationProtocols = ["http/1.1"] } - let tlsEventHandler = TLSEventsHandler(deadline: deadline) let sslServerHostname = self.key.serverNameIndicator let sslContextFuture = self.sslContextCache.sslContext( @@ -296,6 +376,7 @@ extension HTTPConnectionPool.ConnectionFactory { serverHostname: sslServerHostname ) try channel.pipeline.syncOperations.addHandler(sslHandler) + let tlsEventHandler = TLSEventsHandler(deadline: deadline) try channel.pipeline.syncOperations.addHandler(tlsEventHandler) // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a @@ -305,8 +386,14 @@ extension HTTPConnectionPool.ConnectionFactory { return channel.eventLoop.makeFailedFuture(error) } }.flatMap { negotiated -> EventLoopFuture in - channel.pipeline.removeHandler(tlsEventHandler).flatMapThrowing { - try self.matchALPNToHTTPVersion(negotiated, channel: channel) + do { + let sync = channel.pipeline.syncOperations + let context = try sync.context(handlerType: TLSEventsHandler.self) + return sync.removeHandler(context: context).flatMapThrowing { + try self.matchALPNToHTTPVersion(negotiated, channel: channel) + } + } catch { + return channel.eventLoop.makeFailedFuture(error) } } } @@ -319,14 +406,26 @@ extension HTTPConnectionPool.ConnectionFactory { eventLoop: EventLoop ) -> NIOClientTCPBootstrapProtocol { #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { - return tsBootstrap - .channelOption(NIOTSChannelOptions.waitForActivity, value: self.clientConfiguration.networkFrameworkWaitForConnectivity) + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), + let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) + { + return + tsBootstrap + .channelOption( + NIOTSChannelOptions.waitForActivity, + value: self.clientConfiguration.networkFrameworkWaitForConnectivity + ) + .channelOption( + NIOTSChannelOptions.multipathServiceType, + value: self.clientConfiguration.enableMultipath ? .handover : .disabled + ) .connectTimeout(deadline - NIODeadline.now()) .channelInitializer { channel in do { try channel.pipeline.syncOperations.addHandler(HTTPClient.NWErrorHandler()) - try channel.pipeline.syncOperations.addHandler(NWWaitingHandler(requester: requester, connectionID: connectionID)) + try channel.pipeline.syncOperations.addHandler( + NWWaitingHandler(requester: requester, connectionID: connectionID) + ) return channel.eventLoop.makeSucceededVoidFuture() } catch { return channel.eventLoop.makeFailedFuture(error) @@ -336,8 +435,10 @@ extension HTTPConnectionPool.ConnectionFactory { #endif if let nioBootstrap = ClientBootstrap(validatingGroup: eventLoop) { - return nioBootstrap + return + nioBootstrap .connectTimeout(deadline - NIODeadline.now()) + .enableMPTCP(clientConfiguration.enableMultipath) } preconditionFailure("No matching bootstrap found") @@ -360,7 +461,7 @@ extension HTTPConnectionPool.ConnectionFactory { ) var channelFuture = bootstrapFuture.flatMap { bootstrap -> EventLoopFuture in - return bootstrap.connect(target: self.key.connectionTarget) + bootstrap.connect(target: self.key.connectionTarget) }.flatMap { channel -> EventLoopFuture<(Channel, String?)> in do { // if the channel is closed before flatMap is executed, all ChannelHandler are removed @@ -369,11 +470,14 @@ extension HTTPConnectionPool.ConnectionFactory { // The tlsEstablishedFuture is set as soon as the TLSEventsHandler is in a // pipeline. It is created in TLSEventsHandler's handlerAdded method. - return tlsEventHandler.tlsEstablishedFuture!.flatMap { negotiated in - channel.pipeline.removeHandler(tlsEventHandler).map { (channel, negotiated) } - } + return tlsEventHandler.tlsEstablishedFuture!.assumeIsolated().flatMap { negotiated in + channel.pipeline.syncOperations.removeHandler(tlsEventHandler).map { (channel, negotiated) } + }.nonisolated() } catch { - assert(channel.isActive == false, "if the channel is still active then TLSEventsHandler must be present but got error \(error)") + assert( + channel.isActive == false, + "if the channel is still active then TLSEventsHandler must be present but got error \(error)" + ) return channel.eventLoop.makeFailedFuture(HTTPClientError.remoteConnectionClosed) } } @@ -408,19 +512,31 @@ extension HTTPConnectionPool.ConnectionFactory { } #if canImport(Network) - if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), let tsBootstrap = NIOTSConnectionBootstrap(validatingGroup: eventLoop) { + if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), eventLoop is QoSEventLoop { // create NIOClientTCPBootstrap with NIOTS TLS provider - let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions(on: eventLoop, serverNameIndicatorOverride: key.serverNameIndicatorOverride).map { + let bootstrapFuture = tlsConfig.getNWProtocolTLSOptions( + on: eventLoop, + serverNameIndicatorOverride: key.serverNameIndicatorOverride + ).map { options -> NIOClientTCPBootstrapProtocol in - tsBootstrap - .channelOption(NIOTSChannelOptions.waitForActivity, value: self.clientConfiguration.networkFrameworkWaitForConnectivity) + NIOTSConnectionBootstrap(group: eventLoop) // validated above + .channelOption( + NIOTSChannelOptions.waitForActivity, + value: self.clientConfiguration.networkFrameworkWaitForConnectivity + ) + .channelOption( + NIOTSChannelOptions.multipathServiceType, + value: self.clientConfiguration.enableMultipath ? .handover : .disabled + ) .connectTimeout(deadline - NIODeadline.now()) .tlsOptions(options) .channelInitializer { channel in do { try channel.pipeline.syncOperations.addHandler(HTTPClient.NWErrorHandler()) - try channel.pipeline.syncOperations.addHandler(NWWaitingHandler(requester: requester, connectionID: connectionID)) + try channel.pipeline.syncOperations.addHandler( + NWWaitingHandler(requester: requester, connectionID: connectionID) + ) // we don't need to set a TLS deadline for NIOTS connections, since the // TLS handshake is part of the TS connection bootstrap. If the TLS // handshake times out the complete connection creation will be failed. @@ -441,28 +557,29 @@ extension HTTPConnectionPool.ConnectionFactory { logger: logger ) - let bootstrap = ClientBootstrap(group: eventLoop) - .connectTimeout(deadline - NIODeadline.now()) - .channelInitializer { channel in - sslContextFuture.flatMap { sslContext -> EventLoopFuture in - do { - let sync = channel.pipeline.syncOperations - let sslHandler = try NIOSSLClientHandler( - context: sslContext, - serverHostname: self.key.serverNameIndicator - ) - let tlsEventHandler = TLSEventsHandler(deadline: deadline) - - try sync.addHandler(sslHandler) - try sync.addHandler(tlsEventHandler) - return channel.eventLoop.makeSucceededVoidFuture() - } catch { - return channel.eventLoop.makeFailedFuture(error) + return eventLoop.submit { + ClientBootstrap(group: eventLoop) + .connectTimeout(deadline - NIODeadline.now()) + .enableMPTCP(clientConfiguration.enableMultipath) + .channelInitializer { channel in + sslContextFuture.flatMap { sslContext -> EventLoopFuture in + do { + let sync = channel.pipeline.syncOperations + let sslHandler = try NIOSSLClientHandler( + context: sslContext, + serverHostname: self.key.serverNameIndicator + ) + let tlsEventHandler = TLSEventsHandler(deadline: deadline) + + try sync.addHandler(sslHandler) + try sync.addHandler(tlsEventHandler) + return channel.eventLoop.makeSucceededVoidFuture() + } catch { + return channel.eventLoop.makeFailedFuture(error) + } } } - } - - return eventLoop.makeSucceededFuture(bootstrap) + } } private func matchALPNToHTTPVersion(_ negotiated: String?, channel: Channel) throws -> NegotiatedProtocol { diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift index f5a0540cf..3fdf93752 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool+Manager.swift @@ -39,9 +39,11 @@ extension HTTPConnectionPool { private let sslContextCache = SSLContextCache() - init(eventLoopGroup: EventLoopGroup, - configuration: HTTPClient.Configuration, - backgroundActivityLogger logger: Logger) { + init( + eventLoopGroup: EventLoopGroup, + configuration: HTTPClient.Configuration, + backgroundActivityLogger logger: Logger + ) { self.eventLoopGroup = eventLoopGroup self.configuration = configuration self.logger = logger @@ -118,7 +120,7 @@ extension HTTPConnectionPool { promise?.succeed(false) case .shutdown(let pools): - pools.values.forEach { pool in + for pool in pools.values { pool.shutdown() } } @@ -140,7 +142,9 @@ extension HTTPConnectionPool.Manager: HTTPConnectionPoolDelegate { case .shuttingDown(let promise, let soFarUnclean): guard self._pools.removeValue(forKey: pool.key) === pool else { - preconditionFailure("Expected that the pool was created by this manager and is known for this reason.") + preconditionFailure( + "Expected that the pool was created by this manager and is known for this reason." + ) } if self._pools.isEmpty { @@ -154,7 +158,7 @@ extension HTTPConnectionPool.Manager: HTTPConnectionPoolDelegate { } switch closeAction { - case .close(let promise, unclean: let unclean): + case .close(let promise, let unclean): promise?.succeed(unclean) case .wait: break @@ -173,7 +177,7 @@ extension HTTPConnectionPool.Connection.ID { } func next() -> Int { - return self.atomic.loadThenWrappingIncrement(ordering: .relaxed) + self.atomic.loadThenWrappingIncrement(ordering: .relaxed) } } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift index eac4cc21f..251224ac0 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPConnectionPool.swift @@ -21,7 +21,10 @@ protocol HTTPConnectionPoolDelegate { func connectionPoolDidShutdown(_ pool: HTTPConnectionPool, unclean: Bool) } -final class HTTPConnectionPool { +final class HTTPConnectionPool: + // TODO: Refactor to use `NIOLockedValueBox` which will allow this to be checked + @unchecked Sendable +{ private let stateLock = NIOLock() private var _state: StateMachine /// The connection idle timeout timers. Protected by the stateLock @@ -44,14 +47,16 @@ final class HTTPConnectionPool { let delegate: HTTPConnectionPoolDelegate - init(eventLoopGroup: EventLoopGroup, - sslContextCache: SSLContextCache, - tlsConfiguration: TLSConfiguration?, - clientConfiguration: HTTPClient.Configuration, - key: ConnectionPool.Key, - delegate: HTTPConnectionPoolDelegate, - idGenerator: Connection.ID.Generator, - backgroundActivityLogger logger: Logger) { + init( + eventLoopGroup: EventLoopGroup, + sslContextCache: SSLContextCache, + tlsConfiguration: TLSConfiguration?, + clientConfiguration: HTTPClient.Configuration, + key: ConnectionPool.Key, + delegate: HTTPConnectionPoolDelegate, + idGenerator: Connection.ID.Generator, + backgroundActivityLogger logger: Logger + ) { self.eventLoopGroup = eventLoopGroup self.connectionFactory = ConnectionFactory( key: key, @@ -70,8 +75,10 @@ final class HTTPConnectionPool { self._state = StateMachine( idGenerator: idGenerator, - maximumConcurrentHTTP1Connections: clientConfiguration.connectionPool.concurrentHTTP1ConnectionsPerHostSoftLimit, + maximumConcurrentHTTP1Connections: clientConfiguration.connectionPool + .concurrentHTTP1ConnectionsPerHostSoftLimit, retryConnectionEstablishment: clientConfiguration.connectionPool.retryConnectionEstablishment, + preferHTTP1: clientConfiguration.httpVersion == .http1Only, maximumConnectionUses: clientConfiguration.maximumUsesPerConnection ) } @@ -149,7 +156,7 @@ final class HTTPConnectionPool { self.unlocked = Unlocked(connection: .none, request: .none) switch stateMachineAction.request { - case .executeRequest(let request, let connection, cancelTimeout: let cancelTimeout): + case .executeRequest(let request, let connection, let cancelTimeout): if cancelTimeout { self.locked.request = .cancelRequestTimeout(request.id) } @@ -157,7 +164,7 @@ final class HTTPConnectionPool { case .executeRequestsAndCancelTimeouts(let requests, let connection): self.locked.request = .cancelRequestTimeouts(requests) self.unlocked.request = .executeRequests(requests, connection) - case .failRequest(let request, let error, cancelTimeout: let cancelTimeout): + case .failRequest(let request, let error, let cancelTimeout): if cancelTimeout { self.locked.request = .cancelRequestTimeout(request.id) } @@ -174,15 +181,15 @@ final class HTTPConnectionPool { switch stateMachineAction.connection { case .createConnection(let connectionID, on: let eventLoop): self.unlocked.connection = .createConnection(connectionID, on: eventLoop) - case .scheduleBackoffTimer(let connectionID, backoff: let backoff, on: let eventLoop): + case .scheduleBackoffTimer(let connectionID, let backoff, on: let eventLoop): self.locked.connection = .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop) case .scheduleTimeoutTimer(let connectionID, on: let eventLoop): self.locked.connection = .scheduleTimeoutTimer(connectionID, on: eventLoop) case .cancelTimeoutTimer(let connectionID): self.locked.connection = .cancelTimeoutTimer(connectionID) - case .closeConnection(let connection, isShutdown: let isShutdown): + case .closeConnection(let connection, let isShutdown): self.unlocked.connection = .closeConnection(connection, isShutdown: isShutdown) - case .cleanupConnections(var cleanupContext, isShutdown: let isShutdown): + case .cleanupConnections(var cleanupContext, let isShutdown): // self.locked.connection = .cancelBackoffTimers(cleanupContext.connectBackoff) cleanupContext.connectBackoff = [] @@ -220,7 +227,7 @@ final class HTTPConnectionPool { private func runLockedConnectionAction(_ action: Actions.ConnectionAction.Locked) { switch action { - case .scheduleBackoffTimer(let connectionID, backoff: let backoff, on: let eventLoop): + case .scheduleBackoffTimer(let connectionID, let backoff, on: let eventLoop): self.scheduleConnectionStartBackoffTimer(connectionID, backoff, on: eventLoop) case .scheduleTimeoutTimer(let connectionID, on: let eventLoop): @@ -248,7 +255,7 @@ final class HTTPConnectionPool { self.cancelRequestTimeout(requestID) case .cancelRequestTimeouts(let requests): - requests.forEach { self.cancelRequestTimeout($0.id) } + for request in requests { self.cancelRequestTimeout(request.id) } case .none: break @@ -265,10 +272,13 @@ final class HTTPConnectionPool { case .createConnection(let connectionID, let eventLoop): self.createConnection(connectionID, on: eventLoop) - case .closeConnection(let connection, isShutdown: let isShutdown): - self.logger.trace("close connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - ]) + case .closeConnection(let connection, let isShutdown): + self.logger.trace( + "close connection", + metadata: [ + "ahc-connection-id": "\(connection.id)" + ] + ) // we are not interested in the close promise... connection.close(promise: nil) @@ -277,7 +287,7 @@ final class HTTPConnectionPool { self.delegate.connectionPoolDidShutdown(self, unclean: unclean) } - case .cleanupConnections(let cleanupContext, isShutdown: let isShutdown): + case .cleanupConnections(let cleanupContext, let isShutdown): for connection in cleanupContext.close { connection.close(promise: nil) } @@ -314,13 +324,15 @@ final class HTTPConnectionPool { connection.executeRequest(request.req) case .executeRequests(let requests, let connection): - requests.forEach { connection.executeRequest($0.req) } + for request in requests { + connection.executeRequest(request.req) + } case .failRequest(let request, let error): request.req.fail(error) case .failRequests(let requests, let error): - requests.forEach { $0.req.fail(error) } + for request in requests { request.req.fail(error) } case .none: break @@ -328,9 +340,12 @@ final class HTTPConnectionPool { } private func createConnection(_ connectionID: Connection.ID, on eventLoop: EventLoop) { - self.logger.trace("Opening fresh connection", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Opening fresh connection", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) // Even though this function is called make it actually creates/establishes a connection. // TBD: Should we rename it? To what? self.connectionFactory.makeConnection( @@ -373,9 +388,12 @@ final class HTTPConnectionPool { } private func scheduleIdleTimerForConnection(_ connectionID: Connection.ID, on eventLoop: EventLoop) { - self.logger.trace("Schedule idle connection timeout timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Schedule idle connection timeout timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) let scheduled = eventLoop.scheduleTask(in: self.idleConnectionTimeout) { // there might be a race between a cancelTimer call and the triggering // of this scheduled task. both want to acquire the lock @@ -393,9 +411,12 @@ final class HTTPConnectionPool { } private func cancelIdleTimerForConnection(_ connectionID: Connection.ID) { - self.logger.trace("Cancel idle connection timeout timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Cancel idle connection timeout timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) guard let cancelTimer = self._idleTimer.removeValue(forKey: connectionID) else { preconditionFailure("Expected to have an idle timer for connection \(connectionID) at this point.") } @@ -407,9 +428,12 @@ final class HTTPConnectionPool { _ timeAmount: TimeAmount, on eventLoop: EventLoop ) { - self.logger.trace("Schedule connection creation backoff timer", metadata: [ - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.trace( + "Schedule connection creation backoff timer", + metadata: [ + "ahc-connection-id": "\(connectionID)" + ] + ) let scheduled = eventLoop.scheduleTask(in: timeAmount) { // there might be a race between a backoffTimer and the pool shutting down. @@ -437,42 +461,54 @@ final class HTTPConnectionPool { // MARK: - Protocol methods - extension HTTPConnectionPool: HTTPConnectionRequester { - func http1ConnectionCreated(_ connection: HTTP1Connection) { - self.logger.trace("successfully created connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + func http1ConnectionCreated(_ connection: HTTP1Connection.SendableView) { + self.logger.trace( + "successfully created connection", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { $0.newHTTP1ConnectionCreated(.http1_1(connection)) } } - func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { - self.logger.trace("successfully created connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - "ahc-max-streams": "\(maximumStreams)", - ]) + func http2ConnectionCreated(_ connection: HTTP2Connection.SendableView, maximumStreams: Int) { + self.logger.trace( + "successfully created connection", + metadata: [ + "ahc-connection-id": "\(connection.id)", + "ahc-http-version": "http/2", + "ahc-max-streams": "\(maximumStreams)", + ] + ) self.modifyStateAndRunActions { $0.newHTTP2ConnectionCreated(.http2(connection), maxConcurrentStreams: maximumStreams) } } func failedToCreateHTTPConnection(_ connectionID: HTTPConnectionPool.Connection.ID, error: Error) { - self.logger.debug("connection attempt failed", metadata: [ - "ahc-error": "\(error)", - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.debug( + "connection attempt failed", + metadata: [ + "ahc-error": "\(error)", + "ahc-connection-id": "\(connectionID)", + ] + ) self.modifyStateAndRunActions { $0.failedToCreateNewConnection(error, connectionID: connectionID) } } func waitingForConnectivity(_ connectionID: HTTPConnectionPool.Connection.ID, error: Error) { - self.logger.debug("waiting for connectivity", metadata: [ - "ahc-error": "\(error)", - "ahc-connection-id": "\(connectionID)", - ]) + self.logger.debug( + "waiting for connectivity", + metadata: [ + "ahc-error": "\(error)", + "ahc-connection-id": "\(connectionID)", + ] + ) self.modifyStateAndRunActions { $0.waitingForConnectivity(error, connectionID: connectionID) } @@ -480,66 +516,84 @@ extension HTTPConnectionPool: HTTPConnectionRequester { } extension HTTPConnectionPool: HTTP1ConnectionDelegate { - func http1ConnectionClosed(_ connection: HTTP1Connection) { - self.logger.debug("connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + func http1ConnectionClosed(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.debug( + "connection closed", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { - $0.http1ConnectionClosed(connection.id) + $0.http1ConnectionClosed(id) } } - func http1ConnectionReleased(_ connection: HTTP1Connection) { - self.logger.trace("releasing connection", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/1.1", - ]) + func http1ConnectionReleased(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.trace( + "releasing connection", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/1.1", + ] + ) self.modifyStateAndRunActions { - $0.http1ConnectionReleased(connection.id) + $0.http1ConnectionReleased(id) } } } extension HTTPConnectionPool: HTTP2ConnectionDelegate { - func http2Connection(_ connection: HTTP2Connection, newMaxStreamSetting: Int) { - self.logger.debug("new max stream setting", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - "ahc-max-streams": "\(newMaxStreamSetting)", - ]) + func http2Connection(_ id: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) { + self.logger.debug( + "new max stream setting", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + "ahc-max-streams": "\(newMaxStreamSetting)", + ] + ) self.modifyStateAndRunActions { - $0.newHTTP2MaxConcurrentStreamsReceived(connection.id, newMaxStreams: newMaxStreamSetting) + $0.newHTTP2MaxConcurrentStreamsReceived(id, newMaxStreams: newMaxStreamSetting) } } - func http2ConnectionGoAwayReceived(_ connection: HTTP2Connection) { - self.logger.debug("connection go away received", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + func http2ConnectionGoAwayReceived(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.debug( + "connection go away received", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { - $0.http2ConnectionGoAwayReceived(connection.id) + $0.http2ConnectionGoAwayReceived(id) } } - func http2ConnectionClosed(_ connection: HTTP2Connection) { - self.logger.debug("connection closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + func http2ConnectionClosed(_ id: HTTPConnectionPool.Connection.ID) { + self.logger.debug( + "connection closed", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { - $0.http2ConnectionClosed(connection.id) + $0.http2ConnectionClosed(id) } } - func http2ConnectionStreamClosed(_ connection: HTTP2Connection, availableStreams: Int) { - self.logger.trace("stream closed", metadata: [ - "ahc-connection-id": "\(connection.id)", - "ahc-http-version": "http/2", - ]) + func http2ConnectionStreamClosed(_ id: HTTPConnectionPool.Connection.ID, availableStreams: Int) { + self.logger.trace( + "stream closed", + metadata: [ + "ahc-connection-id": "\(id)", + "ahc-http-version": "http/2", + ] + ) self.modifyStateAndRunActions { - $0.http2ConnectionStreamClosed(connection.id) + $0.http2ConnectionStreamClosed(id) } } } @@ -558,18 +612,18 @@ extension HTTPConnectionPool { typealias ID = Int private enum Reference { - case http1_1(HTTP1Connection) - case http2(HTTP2Connection) + case http1_1(HTTP1Connection.SendableView) + case http2(HTTP2Connection.SendableView) case __testOnly_connection(ID, EventLoop) } private let _ref: Reference - fileprivate static func http1_1(_ conn: HTTP1Connection) -> Self { + fileprivate static func http1_1(_ conn: HTTP1Connection.SendableView) -> Self { Connection(_ref: .http1_1(conn)) } - fileprivate static func http2(_ conn: HTTP2Connection) -> Self { + fileprivate static func http2(_ conn: HTTP2Connection.SendableView) -> Self { Connection(_ref: .http2(conn)) } @@ -641,7 +695,9 @@ extension HTTPConnectionPool { return lhsConn.id == rhsConn.id case (.http2(let lhsConn), .http2(let rhsConn)): return lhsConn.id == rhsConn.id - case (.__testOnly_connection(let lhsID, let lhsEventLoop), .__testOnly_connection(let rhsID, let rhsEventLoop)): + case ( + .__testOnly_connection(let lhsID, let lhsEventLoop), .__testOnly_connection(let rhsID, let rhsEventLoop) + ): return lhsID == rhsID && lhsEventLoop === rhsEventLoop default: return false @@ -722,7 +778,7 @@ struct EventLoopID: Hashable { } static func __testOnly_fakeID(_ id: Int) -> EventLoopID { - return EventLoopID(.__testOnly_fakeID(id)) + EventLoopID(.__testOnly_fakeID(id)) } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift index d64ceedd6..bce55eb5b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPExecutableRequest.swift @@ -132,7 +132,7 @@ import NIOSSL /// /// Use this handle to cancel the request, while it is waiting for a free connection, to execute the request. /// This protocol is only intended to be implemented by the `HTTPConnectionPool`. -protocol HTTPRequestScheduler { +protocol HTTPRequestScheduler: Sendable { /// Informs the task queuer that a request has been cancelled. func cancelRequest(_: HTTPSchedulableRequest) } @@ -176,7 +176,7 @@ protocol HTTPSchedulableRequest: HTTPExecutableRequest { /// A handle to the request executor. /// /// This protocol is implemented by the `HTTP1ClientChannelHandler`. -protocol HTTPRequestExecutor { +protocol HTTPRequestExecutor: Sendable { /// Writes a body part into the channel pipeline /// /// This method may be **called on any thread**. The executor needs to ensure thread safety. @@ -201,7 +201,7 @@ protocol HTTPRequestExecutor { func cancelRequest(_ task: HTTPExecutableRequest) } -protocol HTTPExecutableRequest: AnyObject { +protocol HTTPExecutableRequest: AnyObject, Sendable { /// The request's logger var logger: Logger { get } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift index 90578bc87..5c5b893e0 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine+Demand.swift @@ -104,8 +104,8 @@ extension HTTPRequestStateMachine { // forwarded to the user. case .waitingForRead, - .waitingForDemand, - .waitingForReadOrDemand: + .waitingForDemand, + .waitingForReadOrDemand: return nil case .modifying: @@ -174,8 +174,8 @@ extension HTTPRequestStateMachine { return (buffer, .none) case .waitingForReadOrDemand(let buffer), - .waitingForRead(let buffer), - .waitingForDemand(let buffer): + .waitingForRead(let buffer), + .waitingForDemand(let buffer): // Normally this code path should never be hit. However there is one way to trigger // this: // diff --git a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift index b575ae094..e06389360 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/HTTPRequestStateMachine.swift @@ -58,7 +58,7 @@ struct HTTPRequestStateMachine { /// The request is streaming its request body. `expectedBodyLength` has a value, if the request header contained /// a `"content-length"` header field. If the request header contained a `"transfer-encoding" = "chunked"` /// header field, the `expectedBodyLength` is `nil`. - case streaming(expectedBodyLength: Int?, sentBodyBytes: Int, producer: ProducerControlState) + case streaming(expectedBodyLength: Int64?, sentBodyBytes: Int64, producer: ProducerControlState) /// The request has sent its request body and end. case endSent } @@ -161,10 +161,10 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .running(.streaming(_, _, producer: .producing), _), - .running(.endSent, _), - .finished, - .failed: + .running(.streaming(_, _, producer: .producing), _), + .running(.endSent, _), + .finished, + .failed: return .wait case .waitForChannelToBecomeWritable(let head, let metadata): @@ -196,11 +196,11 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.streaming(_, _, producer: .paused), _), - .running(.endSent, _), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(.streaming(_, _, producer: .paused), _), + .running(.endSent, _), + .finished, + .failed: return .wait case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .producing), let responseState): @@ -219,13 +219,16 @@ struct HTTPRequestStateMachine { mutating func errorHappened(_ error: Error) -> Action { if let error = error as? NIOSSLError, - error == .uncleanShutdown, - let action = self.handleNIOSSLUncleanShutdownError() { + error == .uncleanShutdown, + let action = self.handleNIOSSLUncleanShutdownError() + { return action } switch self.state { case .initialized: - preconditionFailure("After the state machine has been initialized, start must be called immediately. Thus this state is unreachable") + preconditionFailure( + "After the state machine has been initialized, start must be called immediately. Thus this state is unreachable" + ) case .waitForChannelToBecomeWritable: // the request failed, before it was sent onto the wire. self.state = .failed(error) @@ -247,14 +250,14 @@ struct HTTPRequestStateMachine { private mutating func handleNIOSSLUncleanShutdownError() -> Action? { switch self.state { case .running(.streaming, .waitingForHead), - .running(.endSent, .waitingForHead): + .running(.endSent, .waitingForHead): // if we received a NIOSSL.uncleanShutdown before we got an answer we should handle // this like a normal connection close. We will receive a call to channelInactive after // this error. return .wait case .running(.streaming, .receivingBody(let responseHead, _)), - .running(.endSent, .receivingBody(let responseHead, _)): + .running(.endSent, .receivingBody(let responseHead, _)): // This code is only reachable for request and responses, which we expect to have a body. // We depend on logic from the HTTPResponseDecoder here. The decoder will emit an // HTTPResponsePart.end right after the HTTPResponsePart.head, for every request with a @@ -263,7 +266,9 @@ struct HTTPRequestStateMachine { // For this reason we only need to check the "content-length" or "transfer-encoding" // headers here to determine if we are potentially in an EOF terminated response. - if responseHead.headers.contains(name: "content-length") || responseHead.headers.contains(name: "transfer-encoding") { + if responseHead.headers.contains(name: "content-length") + || responseHead.headers.contains(name: "transfer-encoding") + { // If we have already received the response head, the parser will ensure that we // receive a complete response, if the content-length or transfer-encoding header // was set. In this case we can ignore the NIOSSLError.uncleanShutdown. We will see @@ -285,9 +290,11 @@ struct HTTPRequestStateMachine { mutating func requestStreamPartReceived(_ part: IOData, promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.endSent, _): - preconditionFailure("We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.endSent, _): + preconditionFailure( + "We must be in the request streaming phase, if we receive further body parts. Invalid state: \(self.state)" + ) case .running(.streaming(_, _, let producerState), .receivingBody(let head, _)) where head.status.code >= 300: // If we have already received a response head with status >= 300, we won't send out any @@ -308,13 +315,13 @@ struct HTTPRequestStateMachine { // pause. The reason for this is as follows: There might be thread synchronization // situations in which the producer might not have received the plea to pause yet. - if let expected = expectedBodyLength, sentBodyBytes + part.readableBytes > expected { + if let expected = expectedBodyLength, sentBodyBytes + Int64(part.readableBytes) > expected { let error = HTTPClientError.bodyLengthMismatch self.state = .failed(error) return .failRequest(error, .close(promise)) } - sentBodyBytes += part.readableBytes + sentBodyBytes += Int64(part.readableBytes) let requestState: RequestState = .streaming( expectedBodyLength: expectedBodyLength, @@ -349,9 +356,11 @@ struct HTTPRequestStateMachine { mutating func requestStreamFinished(promise: EventLoopPromise?) -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.endSent, _): - preconditionFailure("A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.endSent, _): + preconditionFailure( + "A request body stream end is only expected if we are in state request streaming. Invalid state: \(self.state)" + ) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .waitingForHead): if let expected = expectedBodyLength, expected != sentBodyBytes { @@ -363,7 +372,10 @@ struct HTTPRequestStateMachine { self.state = .running(.endSent, .waitingForHead) return .sendRequestEnd(promise) - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, _), .receivingBody(let head, let streamState)): + case .running( + .streaming(let expectedBodyLength, let sentBodyBytes, _), + .receivingBody(let head, let streamState) + ): assert(head.status.code < 300) if let expected = expectedBodyLength, expected != sentBodyBytes { @@ -456,11 +468,11 @@ struct HTTPRequestStateMachine { mutating func read() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(_, .waitingForHead), - .running(_, .endReceived), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(_, .waitingForHead), + .running(_, .endReceived), + .finished, + .failed: // If we are not in the middle of streaming the response body, we always want to get // more data... return .read @@ -493,11 +505,11 @@ struct HTTPRequestStateMachine { mutating func channelReadComplete() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(_, .waitingForHead), - .running(_, .endReceived), - .finished, - .failed: + .waitForChannelToBecomeWritable, + .running(_, .waitingForHead), + .running(_, .endReceived), + .finished, + .failed: return .wait case .running(let requestState, .receivingBody(let head, var streamState)): @@ -528,7 +540,9 @@ struct HTTPRequestStateMachine { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before sending a request head ourselves \(self.state)") + preconditionFailure( + "How can we receive a response head before sending a request head ourselves \(self.state)" + ) case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), .waitingForHead): self.state = .running( @@ -546,7 +560,11 @@ struct HTTPRequestStateMachine { return .forwardResponseHead(head, pauseRequestBodyStream: true) } else { self.state = .running( - .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: .producing), + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: .producing + ), .receivingBody(head, .init()) ) return .forwardResponseHead(head, pauseRequestBodyStream: false) @@ -557,7 +575,9 @@ struct HTTPRequestStateMachine { return .forwardResponseHead(head, pauseRequestBodyStream: false) case .running(_, .receivingBody), .running(_, .endReceived), .finished: - preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we successfully finish the request, before having received a head. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -569,10 +589,14 @@ struct HTTPRequestStateMachine { mutating func receivedHTTPResponseBodyPart(_ body: ByteBuffer) -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response head before completely sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response head before completely sending a request head ourselves. Invalid state: \(self.state)" + ) case .running(_, .waitingForHead): - preconditionFailure("How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response body, if we haven't received a head. Invalid state: \(self.state)" + ) case .running(let requestState, .receivingBody(let head, var responseStreamState)): return self.avoidingStateMachineCoW { state -> Action in @@ -582,7 +606,9 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), .finished: - preconditionFailure("How can we successfully finish the request, before having received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we successfully finish the request, before having received a head. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -595,20 +621,31 @@ struct HTTPRequestStateMachine { private mutating func receivedHTTPResponseEnd() -> Action { switch self.state { case .initialized, .waitForChannelToBecomeWritable: - preconditionFailure("How can we receive a response end before completely sending a request head ourselves. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end before completely sending a request head ourselves. Invalid state: \(self.state)" + ) case .running(_, .waitingForHead): - preconditionFailure("How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end, if we haven't a received a head. Invalid state: \(self.state)" + ) - case .running(.streaming(let expectedBodyLength, let sentBodyBytes, let producerState), .receivingBody(let head, var responseStreamState)) - where head.status.code < 300: + case .running( + .streaming(let expectedBodyLength, let sentBodyBytes, let producerState), + .receivingBody(let head, var responseStreamState) + ) + where head.status.code < 300: return self.avoidingStateMachineCoW { state -> Action in let (remainingBuffer, connectionAction) = responseStreamState.end() switch connectionAction { case .none: state = .running( - .streaming(expectedBodyLength: expectedBodyLength, sentBodyBytes: sentBodyBytes, producer: producerState), + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: producerState + ), .endReceived ) return .forwardResponseBodyParts(remainingBuffer) @@ -624,7 +661,10 @@ struct HTTPRequestStateMachine { case .running(.streaming(_, _, let producerState), .receivingBody(let head, var responseStreamState)): assert(head.status.code >= 300) - assert(producerState == .paused, "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)") + assert( + producerState == .paused, + "Expected to have paused the request body stream, when the head was received. Invalid state: \(self.state)" + ) return self.avoidingStateMachineCoW { state -> Action in // We can ignore the connectionAction from the responseStreamState, since the @@ -647,7 +687,9 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), .finished: - preconditionFailure("How can we receive a response end, if another one was already received. Invalid state: \(self.state)") + preconditionFailure( + "How can we receive a response end, if another one was already received. Invalid state: \(self.state)" + ) case .failed: return .wait @@ -660,9 +702,11 @@ struct HTTPRequestStateMachine { mutating func demandMoreResponseBodyParts() -> Action { switch self.state { case .initialized, - .running(_, .waitingForHead), - .waitForChannelToBecomeWritable: - preconditionFailure("The response is expected to only ask for more data after the response head was forwarded \(self.state)") + .running(_, .waitingForHead), + .waitForChannelToBecomeWritable: + preconditionFailure( + "The response is expected to only ask for more data after the response head was forwarded \(self.state)" + ) case .running(let requestState, .receivingBody(let head, var responseStreamState)): return self.avoidingStateMachineCoW { state -> Action in @@ -672,8 +716,8 @@ struct HTTPRequestStateMachine { } case .running(_, .endReceived), - .finished, - .failed: + .finished, + .failed: return .wait case .modifying: @@ -684,9 +728,11 @@ struct HTTPRequestStateMachine { mutating func idleReadTimeoutTriggered() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable, - .running(.streaming, _): - preconditionFailure("We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable, + .running(.streaming, _): + preconditionFailure( + "We only schedule idle read timeouts after we have sent the complete request. Invalid state: \(self.state)" + ) case .running(.endSent, .waitingForHead), .running(.endSent, .receivingBody): let error = HTTPClientError.readTimeout @@ -707,8 +753,10 @@ struct HTTPRequestStateMachine { mutating func idleWriteTimeoutTriggered() -> Action { switch self.state { case .initialized, - .waitForChannelToBecomeWritable: - preconditionFailure("We only schedule idle write timeouts while the request is being sent. Invalid state: \(self.state)") + .waitForChannelToBecomeWritable: + preconditionFailure( + "We only schedule idle write timeouts while the request is being sent. Invalid state: \(self.state)" + ) case .running(.streaming, _): let error = HTTPClientError.writeTimeout @@ -733,7 +781,10 @@ struct HTTPRequestStateMachine { self.state = .running(.endSent, .waitingForHead) return .sendRequestHead(head, sendEnd: true) } else { - self.state = .running(.streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .paused), .waitingForHead) + self.state = .running( + .streaming(expectedBodyLength: length, sentBodyBytes: 0, producer: .paused), + .waitingForHead + ) return .sendRequestHead(head, sendEnd: false) } } @@ -745,11 +796,14 @@ struct HTTPRequestStateMachine { case .running(.streaming(let expectedBodyLength, let sentBodyBytes, producer: .paused), let responseState): let startProducing = self.isChannelWritable && expectedBodyLength != sentBodyBytes - self.state = .running(.streaming( - expectedBodyLength: expectedBodyLength, - sentBodyBytes: sentBodyBytes, - producer: startProducing ? .producing : .paused - ), responseState) + self.state = .running( + .streaming( + expectedBodyLength: expectedBodyLength, + sentBodyBytes: sentBodyBytes, + producer: startProducing ? .producing : .paused + ), + responseState + ) return .notifyRequestHeadSendSuccessfully( resumeRequestBodyStream: startProducing, startIdleTimer: false @@ -757,7 +811,9 @@ struct HTTPRequestStateMachine { case .running(.endSent, _): return .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) case .running(.streaming(_, _, producer: .producing), _): - preconditionFailure("request body producing can not start before we have successfully send the header \(self.state)") + preconditionFailure( + "request body producing can not start before we have successfully send the header \(self.state)" + ) case .failed: return .wait @@ -768,7 +824,7 @@ struct HTTPRequestStateMachine { } extension RequestFramingMetadata.Body { - var expectedLength: Int? { + var expectedLength: Int64? { switch self { case .fixedSize(let length): return length case .stream: return nil @@ -830,7 +886,8 @@ extension HTTPRequestStateMachine: CustomStringConvertible { case .waitForChannelToBecomeWritable: return "HTTPRequestStateMachine(.waitForChannelToBecomeWritable, isWritable: \(self.isChannelWritable))" case .running(let requestState, let responseState): - return "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" + return + "HTTPRequestStateMachine(.running(request: \(requestState), response: \(responseState)), isWritable: \(self.isChannelWritable))" case .finished: return "HTTPRequestStateMachine(.finished, isWritable: \(self.isChannelWritable))" case .failed(let error): @@ -844,7 +901,7 @@ extension HTTPRequestStateMachine: CustomStringConvertible { extension HTTPRequestStateMachine.RequestState: CustomStringConvertible { var description: String { switch self { - case .streaming(expectedBodyLength: let expected, let sent, producer: let producer): + case .streaming(expectedBodyLength: let expected, let sent, let producer): return ".streaming(sent: \(expected != nil ? String(expected!) : "-"), sent: \(sent), producer: \(producer)" case .endSent: return ".endSent" diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift index 83f0e6edf..58ba694a7 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestBodyLength.swift @@ -20,5 +20,5 @@ internal enum RequestBodyLength: Hashable, Sendable { /// size of the request body is not known before starting the request case unknown /// size of the request body is fixed and exactly `count` bytes - case known(_ count: Int) + case known(_ count: Int64) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift index 98080e364..033060a99 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/RequestFramingMetadata.swift @@ -15,7 +15,7 @@ struct RequestFramingMetadata: Hashable { enum Body: Hashable { case stream - case fixedSize(Int) + case fixedSize(Int64) } var connectionClose: Bool diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift index cc7c7cfa1..71d8f15f1 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+Backoff.swift @@ -13,10 +13,13 @@ //===----------------------------------------------------------------------===// import NIOCore + #if canImport(Darwin) import func Darwin.pow #elseif canImport(Musl) import func Musl.pow +#elseif canImport(Android) +import func Android.pow #else import func Glibc.pow #endif diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift index 935cdb2f6..15138a141 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1Connections.swift @@ -71,7 +71,7 @@ extension HTTPConnectionPool { var idleAndNoRemainingUses: Bool { switch self.state { - case .idle(_, since: _, remainingUses: let remainingUses): + case .idle(_, since: _, let remainingUses): if let remainingUses = remainingUses { return remainingUses <= 0 } else { @@ -139,7 +139,7 @@ extension HTTPConnectionPool { mutating func lease() -> Connection { switch self.state { - case .idle(let connection, since: _, remainingUses: let remainingUses): + case .idle(let connection, since: _, let remainingUses): self.state = .leased(connection, remainingUses: remainingUses.map { $0 - 1 }) return connection case .backingOff, .starting, .leased, .closed: @@ -208,7 +208,9 @@ extension HTTPConnectionPool { context.cancel.append(connection) return .keepConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } @@ -232,7 +234,9 @@ extension HTTPConnectionPool { case .leased: return .keepConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } } @@ -261,7 +265,7 @@ extension HTTPConnectionPool { init(maximumConcurrentConnections: Int, generator: Connection.ID.Generator, maximumConnectionUses: Int?) { self.connections = [] - self.connections.reserveCapacity(maximumConcurrentConnections) + self.connections.reserveCapacity(min(maximumConcurrentConnections, 1024)) self.overflowIndex = self.connections.endIndex self.maximumConcurrentConnections = maximumConcurrentConnections self.generator = generator @@ -307,7 +311,7 @@ extension HTTPConnectionPool { } private var maximumAdditionalGeneralPurposeConnections: Int { - self.maximumConcurrentConnections - (self.overflowIndex - 1) + self.maximumConcurrentConnections - (self.overflowIndex) } /// Is there at least one connection that is able to run requests @@ -316,7 +320,7 @@ extension HTTPConnectionPool { } func startingEventLoopConnections(on eventLoop: EventLoop) -> Int { - return self.connections[self.overflowIndex.. [(Connection.ID, EventLoop)] { // create new connections for requests with a required event loop - // we may already start connections for those requests and do not want to start to many + // we may already start connections for those requests and do not want to start too many let startingRequiredEventLoopConnectionCount = Dictionary( self.connections[self.overflowIndex.. [(Connection.ID, EventLoop)] in // We need a connection for each queued request with a required event loop. // Therefore, we look how many request we have queued for a given `eventLoop` and // how many connections we are already starting on the given `eventLoop`. // If we have not enough, we will create additional connections to have at least // on connection per request. - let connectionsToStart = requestCount - startingRequiredEventLoopConnectionCount[eventLoop.id, default: 0] + let connectionsToStart = + requestCount - startingRequiredEventLoopConnectionCount[eventLoop.id, default: 0] return stride(from: 0, to: connectionsToStart, by: 1).lazy.map { _ in (self.createNewOverflowConnection(on: eventLoop), eventLoop) } @@ -666,7 +677,8 @@ extension HTTPConnectionPool { // event loop we will continue with the event loop with the second most queued requests // and so on and so forth. The `generalPurposeRequestCountGroupedByPreferredEventLoop` // array is already ordered so we can just iterate over it without sorting by request count. - let newGeneralPurposeConnections: [(Connection.ID, EventLoop)] = generalPurposeRequestCountGroupedByPreferredEventLoop + let newGeneralPurposeConnections: [(Connection.ID, EventLoop)] = + generalPurposeRequestCountGroupedByPreferredEventLoop // we do not want to allocated intermediate arrays. .lazy // we flatten the grouped list of event loops by lazily repeating the event loop diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift index 2629b0ea2..09b1dc85e 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP1StateMachine.swift @@ -80,7 +80,10 @@ extension HTTPConnectionPool { requests: RequestQueue ) -> ConnectionMigrationAction { precondition(self.connections.isEmpty, "expected an empty state machine but connections are not empty") - precondition(self.http2Connections == nil, "expected an empty state machine but http2Connections are not nil") + precondition( + self.http2Connections == nil, + "expected an empty state machine but http2Connections are not nil" + ) precondition(self.requests.isEmpty, "expected an empty state machine but requests are not empty") self.requests = requests @@ -100,7 +103,8 @@ extension HTTPConnectionPool { let createConnections = self.connections.createConnectionsAfterMigrationIfNeeded( requiredEventLoopOfPendingRequests: requests.requestCountGroupedByRequiredEventLoop(), - generalPurposeRequestCountGroupedByPreferredEventLoop: requests.generalPurposeRequestCountGroupedByPreferredEventLoop() + generalPurposeRequestCountGroupedByPreferredEventLoop: + requests.generalPurposeRequestCountGroupedByPreferredEventLoop() ) if !http2Connections.isEmpty { @@ -229,7 +233,9 @@ extension HTTPConnectionPool { case .running: guard self.retryConnectionEstablishment else { guard let (index, _) = self.connections.failConnection(connectionID) else { - preconditionFailure("A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost.") + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) } self.connections.removeConnection(at: index) @@ -295,7 +301,10 @@ extension HTTPConnectionPool { return .none } - precondition(self.lifecycleState == .running, "If we are shutting down, we must not have any idle connections") + precondition( + self.lifecycleState == .running, + "If we are shutting down, we must not have any idle connections" + ) return .init( request: .none, @@ -561,7 +570,8 @@ extension HTTPConnectionPool { // MARK: HTTP2 - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { // The `http2Connections` are optional here: // Connections report events back to us, if they are in a shutdown that was // initiated by the state machine. For this reason this callback might be invoked @@ -663,6 +673,7 @@ extension HTTPConnectionPool.HTTP1StateMachine: CustomStringConvertible { let stats = self.connections.stats let queued = self.requests.count - return "connections: [connecting: \(stats.connecting) | backoff: \(stats.backingOff) | leased: \(stats.leased) | idle: \(stats.idle)], queued: \(queued)" + return + "connections: [connecting: \(stats.connecting) | backoff: \(stats.backingOff) | leased: \(stats.leased) | idle: \(stats.idle)], queued: \(queued)" } } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift index 01d68b8e4..dbb6b2d30 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2Connections.swift @@ -117,7 +117,13 @@ extension HTTPConnectionPool { preconditionFailure("Invalid state: \(self.state)") case .starting(let maxUses): - self.state = .active(conn, maxStreams: maxStreams, usedStreams: 0, lastIdle: .now(), remainingUses: maxUses) + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: 0, + lastIdle: .now(), + remainingUses: maxUses + ) if let maxUses = maxUses { return min(maxStreams, maxUses) } else { @@ -136,7 +142,13 @@ extension HTTPConnectionPool { preconditionFailure("Invalid state for updating max concurrent streams: \(self.state)") case .active(let conn, _, let usedStreams, let lastIdle, let remainingUses): - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle, remainingUses: remainingUses) + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses + ) let availableStreams = max(maxStreams - usedStreams, 0) if let remainingUses = remainingUses { return min(remainingUses, availableStreams) @@ -192,8 +204,17 @@ extension HTTPConnectionPool { case .active(let conn, let maxStreams, var usedStreams, let lastIdle, let remainingUses): usedStreams += count precondition(usedStreams <= maxStreams, "tried to lease a connection which is not available") - precondition(remainingUses.map { $0 >= count } ?? true, "tried to lease streams from a connection which does not have enough remaining streams") - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle, remainingUses: remainingUses.map { $0 - count }) + precondition( + remainingUses.map { $0 >= count } ?? true, + "tried to lease streams from a connection which does not have enough remaining streams" + ) + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses.map { $0 - count } + ) return conn } } @@ -212,7 +233,13 @@ extension HTTPConnectionPool { lastIdle = .now() } - self.state = .active(conn, maxStreams: maxStreams, usedStreams: usedStreams, lastIdle: lastIdle, remainingUses: remainingUses) + self.state = .active( + conn, + maxStreams: maxStreams, + usedStreams: usedStreams, + lastIdle: lastIdle, + remainingUses: remainingUses + ) let availableStreams = max(maxStreams &- usedStreams, 0) if let remainingUses = remainingUses { return min(availableStreams, remainingUses) @@ -282,7 +309,9 @@ extension HTTPConnectionPool { return .keepConnection case .closed: - preconditionFailure("Unexpected state for cleanup: Did not expect to have closed connections in the state machine.") + preconditionFailure( + "Unexpected state for cleanup: Did not expect to have closed connections in the state machine." + ) } } @@ -341,7 +370,9 @@ extension HTTPConnectionPool { return .removeConnection case .closed: - preconditionFailure("Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)") + preconditionFailure( + "Unexpected state: Did not expect to have connections with this state in the state machine: \(self.state)" + ) } } @@ -388,16 +419,20 @@ extension HTTPConnectionPool { backingOff: [(Connection.ID, EventLoop)] ) { for (connectionID, eventLoop) in starting { - let newConnection = HTTP2ConnectionState(connectionID: connectionID, - eventLoop: eventLoop, - maximumUses: self.maximumConnectionUses) + let newConnection = HTTP2ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(newConnection) } for (connectionID, eventLoop) in backingOff { - var backingOffConnection = HTTP2ConnectionState(connectionID: connectionID, - eventLoop: eventLoop, - maximumUses: self.maximumConnectionUses) + var backingOffConnection = HTTP2ConnectionState( + connectionID: connectionID, + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) // TODO: Maybe we want to add a static init for backing off connections to HTTP2ConnectionState backingOffConnection.failedToConnect() self.connections.append(backingOffConnection) @@ -503,9 +538,11 @@ extension HTTPConnectionPool { "we should not create more than one connection per event loop" ) - let connection = HTTP2ConnectionState(connectionID: self.generator.next(), - eventLoop: eventLoop, - maximumUses: self.maximumConnectionUses) + let connection = HTTP2ConnectionState( + connectionID: self.generator.next(), + eventLoop: eventLoop, + maximumUses: self.maximumConnectionUses + ) self.connections.append(connection) return connection.connectionID } @@ -518,11 +555,17 @@ extension HTTPConnectionPool { /// - Returns: An index and an ``EstablishedConnectionContext`` to determine the next action for the now idle connection. /// Call ``leaseStreams(at:count:)`` or ``closeConnection(at:)`` with the supplied index after /// this. - mutating func newHTTP2ConnectionEstablished(_ connection: Connection, maxConcurrentStreams: Int) -> (Int, EstablishedConnectionContext) { + mutating func newHTTP2ConnectionEstablished( + _ connection: Connection, + maxConcurrentStreams: Int + ) -> (Int, EstablishedConnectionContext) { guard let index = self.connections.firstIndex(where: { $0.connectionID == connection.id }) else { preconditionFailure("There is a new connection that we didn't request!") } - precondition(connection.eventLoop === self.connections[index].eventLoop, "Expected the new connection to be on EL") + precondition( + connection.eventLoop === self.connections[index].eventLoop, + "Expected the new connection to be on EL" + ) let availableStreams = self.connections[index].connected(connection, maxStreams: maxConcurrentStreams) let context = EstablishedConnectionContext( availableStreams: availableStreams, diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift index 83a7647f4..2372cab4b 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+HTTP2StateMachine.swift @@ -47,8 +47,10 @@ extension HTTPConnectionPool { self.idGenerator = idGenerator self.requests = RequestQueue() - self.connections = HTTP2Connections(generator: idGenerator, - maximumConnectionUses: maximumConnectionUses) + self.connections = HTTP2Connections( + generator: idGenerator, + maximumConnectionUses: maximumConnectionUses + ) self.lifecycleState = lifecycleState self.retryConnectionEstablishment = retryConnectionEstablishment } @@ -83,7 +85,10 @@ extension HTTPConnectionPool { requests: RequestQueue ) -> ConnectionMigrationAction { precondition(self.connections.isEmpty, "expected an empty state machine but connections are not empty") - precondition(self.http1Connections == nil, "expected an empty state machine but http1Connections are not nil") + precondition( + self.http1Connections == nil, + "expected an empty state machine but http1Connections are not nil" + ) precondition(self.requests.isEmpty, "expected an empty state machine but requests are not empty") self.requests = requests @@ -93,7 +98,7 @@ extension HTTPConnectionPool { self.connections = http2Connections } - var http1Connections = http1Connections // make http1Connections mutable + var http1Connections = http1Connections // make http1Connections mutable let context = http1Connections.migrateToHTTP2() self.connections.migrateFromHTTP1( starting: context.starting, @@ -215,7 +220,10 @@ extension HTTPConnectionPool { .init(self._newHTTP2ConnectionEstablished(connection, maxConcurrentStreams: maxConcurrentStreams)) } - private mutating func _newHTTP2ConnectionEstablished(_ connection: Connection, maxConcurrentStreams: Int) -> EstablishedAction { + private mutating func _newHTTP2ConnectionEstablished( + _ connection: Connection, + maxConcurrentStreams: Int + ) -> EstablishedAction { self.failedConsecutiveConnectionAttempts = 0 self.lastConnectFailure = nil if self.connections.hasActiveConnection(for: connection.eventLoop) { @@ -296,8 +304,14 @@ extension HTTPConnectionPool { } } - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { - guard let (index, context) = self.connections.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) else { + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { + guard + let (index, context) = self.connections.newHTTP2MaxConcurrentStreamsReceived( + connectionID, + newMaxStreams: newMaxStreams + ) + else { // When a connection close is initiated by the connection pool, the connection will // still report further events (like newMaxConcurrentStreamsReceived) to the state // machine. In those cases we must ignore the event. @@ -341,15 +355,15 @@ extension HTTPConnectionPool { // we need to start a new on connection in two cases: let needGeneralPurposeConnection = // 1. if we have general purpose requests - !self.requests.isEmpty(for: nil) && + !self.requests.isEmpty(for: nil) // and no connection starting or active - !context.hasGeneralPurposeConnection + && !context.hasGeneralPurposeConnection let needRequiredEventLoopConnection = // 2. or if we have requests for a required event loop - !self.requests.isEmpty(for: eventLoop) && + !self.requests.isEmpty(for: eventLoop) // and no connection starting or active for the given event loop - !context.hasConnectionOnSpecifiedEventLoop + && !context.hasConnectionOnSpecifiedEventLoop guard needGeneralPurposeConnection || needRequiredEventLoopConnection else { // otherwise we can remove the connection @@ -357,7 +371,8 @@ extension HTTPConnectionPool { return .none } - let (newConnectionID, previousEventLoop) = self.connections.createNewConnectionByReplacingClosedConnection(at: index) + let (newConnectionID, previousEventLoop) = self.connections + .createNewConnectionByReplacingClosedConnection(at: index) precondition(previousEventLoop === eventLoop) return .init( @@ -413,7 +428,9 @@ extension HTTPConnectionPool { case .running: guard self.retryConnectionEstablishment else { guard let (index, _) = self.connections.failConnection(connectionID) else { - preconditionFailure("A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost.") + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) } self.connections.removeConnection(at: index) @@ -425,10 +442,15 @@ extension HTTPConnectionPool { let eventLoop = self.connections.backoffNextConnectionAttempt(connectionID) let backoff = calculateBackoff(failedAttempt: self.failedConsecutiveConnectionAttempts) - return .init(request: .none, connection: .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop)) + return .init( + request: .none, + connection: .scheduleBackoffTimer(connectionID, backoff: backoff, on: eventLoop) + ) case .shuttingDown: guard let (index, context) = self.connections.failConnection(connectionID) else { - preconditionFailure("A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost.") + preconditionFailure( + "A connection attempt failed, that the state machine knows nothing about. Somewhere state was lost." + ) } return self.nextActionForFailedConnection(at: index, on: context.eventLoop) case .shutDown: @@ -505,7 +527,10 @@ extension HTTPConnectionPool { return .none } - precondition(self.lifecycleState == .running, "If we are shutting down, we must not have any idle connections") + precondition( + self.lifecycleState == .running, + "If we are shutting down, we must not have any idle connections" + ) return .init( request: .none, @@ -558,7 +583,10 @@ extension HTTPConnectionPool { case .shuttingDown(let unclean): if self.connections.isEmpty { // if the http2connections are empty as well, there are no more connections. Shutdown completed. - return .init(request: .none, connection: .closeConnection(connection, isShutdown: .yes(unclean: unclean))) + return .init( + request: .none, + connection: .closeConnection(connection, isShutdown: .yes(unclean: unclean)) + ) } else { return .init(request: .none, connection: .closeConnection(connection, isShutdown: .no)) } diff --git a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift index a61471a69..0cc02cf0f 100644 --- a/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift +++ b/Sources/AsyncHTTPClient/ConnectionPool/State Machine/HTTPConnectionPool+StateMachine.swift @@ -26,7 +26,9 @@ extension HTTPConnectionPool { self.connection = connection } - static let none = Action(request: .none, connection: .none) + static var none: Action { + Action(request: .none, connection: .none) + } } enum ConnectionAction { @@ -105,28 +107,43 @@ extension HTTPConnectionPool { idGenerator: Connection.ID.Generator, maximumConcurrentHTTP1Connections: Int, retryConnectionEstablishment: Bool, + preferHTTP1: Bool, maximumConnectionUses: Int? ) { self.maximumConcurrentHTTP1Connections = maximumConcurrentHTTP1Connections self.retryConnectionEstablishment = retryConnectionEstablishment self.idGenerator = idGenerator self.maximumConnectionUses = maximumConnectionUses - let http1State = HTTP1StateMachine( - idGenerator: idGenerator, - maximumConcurrentConnections: maximumConcurrentHTTP1Connections, - retryConnectionEstablishment: retryConnectionEstablishment, - maximumConnectionUses: maximumConnectionUses, - lifecycleState: .running - ) - self.state = .http1(http1State) + + if preferHTTP1 { + let http1State = HTTP1StateMachine( + idGenerator: idGenerator, + maximumConcurrentConnections: maximumConcurrentHTTP1Connections, + retryConnectionEstablishment: retryConnectionEstablishment, + maximumConnectionUses: maximumConnectionUses, + lifecycleState: .running + ) + self.state = .http1(http1State) + } else { + let http2State = HTTP2StateMachine( + idGenerator: idGenerator, + retryConnectionEstablishment: retryConnectionEstablishment, + lifecycleState: .running, + maximumConnectionUses: maximumConnectionUses + ) + self.state = .http2(http2State) + } } mutating func executeRequest(_ request: Request) -> Action { - self.state.modify(http1: { http1 in - http1.executeRequest(request) - }, http2: { http2 in - http2.executeRequest(request) - }) + self.state.modify( + http1: { http1 in + http1.executeRequest(request) + }, + http2: { http2 in + http2.executeRequest(request) + } + ) } mutating func newHTTP1ConnectionCreated(_ connection: Connection) -> Action { @@ -187,60 +204,82 @@ extension HTTPConnectionPool { } } - mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action { - self.state.modify(http1: { http1 in - http1.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) - }, http2: { http2 in - http2.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) - }) + mutating func newHTTP2MaxConcurrentStreamsReceived(_ connectionID: Connection.ID, newMaxStreams: Int) -> Action + { + self.state.modify( + http1: { http1 in + http1.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) + }, + http2: { http2 in + http2.newHTTP2MaxConcurrentStreamsReceived(connectionID, newMaxStreams: newMaxStreams) + } + ) } mutating func http2ConnectionGoAwayReceived(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionGoAwayReceived(connectionID) - }, http2: { http2 in - http2.http2ConnectionGoAwayReceived(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionGoAwayReceived(connectionID) + }, + http2: { http2 in + http2.http2ConnectionGoAwayReceived(connectionID) + } + ) } mutating func http2ConnectionClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionClosed(connectionID) - }, http2: { http2 in - http2.http2ConnectionClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionClosed(connectionID) + }, + http2: { http2 in + http2.http2ConnectionClosed(connectionID) + } + ) } mutating func http2ConnectionStreamClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http2ConnectionStreamClosed(connectionID) - }, http2: { http2 in - http2.http2ConnectionStreamClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http2ConnectionStreamClosed(connectionID) + }, + http2: { http2 in + http2.http2ConnectionStreamClosed(connectionID) + } + ) } mutating func failedToCreateNewConnection(_ error: Error, connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.failedToCreateNewConnection(error, connectionID: connectionID) - }, http2: { http2 in - http2.failedToCreateNewConnection(error, connectionID: connectionID) - }) + self.state.modify( + http1: { http1 in + http1.failedToCreateNewConnection(error, connectionID: connectionID) + }, + http2: { http2 in + http2.failedToCreateNewConnection(error, connectionID: connectionID) + } + ) } mutating func waitingForConnectivity(_ error: Error, connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.waitingForConnectivity(error, connectionID: connectionID) - }, http2: { http2 in - http2.waitingForConnectivity(error, connectionID: connectionID) - }) + self.state.modify( + http1: { http1 in + http1.waitingForConnectivity(error, connectionID: connectionID) + }, + http2: { http2 in + http2.waitingForConnectivity(error, connectionID: connectionID) + } + ) } mutating func connectionCreationBackoffDone(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.connectionCreationBackoffDone(connectionID) - }, http2: { http2 in - http2.connectionCreationBackoffDone(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.connectionCreationBackoffDone(connectionID) + }, + http2: { http2 in + http2.connectionCreationBackoffDone(connectionID) + } + ) } /// A request has timed out. @@ -249,11 +288,14 @@ extension HTTPConnectionPool { /// request, but don't need to cancel the timer (it already triggered). If a request is cancelled /// we don't need to fail it but we need to cancel its timeout timer. mutating func timeoutRequest(_ requestID: Request.ID) -> Action { - self.state.modify(http1: { http1 in - http1.timeoutRequest(requestID) - }, http2: { http2 in - http2.timeoutRequest(requestID) - }) + self.state.modify( + http1: { http1 in + http1.timeoutRequest(requestID) + }, + http2: { http2 in + http2.timeoutRequest(requestID) + } + ) } /// A request was cancelled. @@ -262,44 +304,59 @@ extension HTTPConnectionPool { /// need to cancel its timeout timer. If a request times out, we need to fail the request, but don't /// need to cancel the timer (it already triggered). mutating func cancelRequest(_ requestID: Request.ID) -> Action { - self.state.modify(http1: { http1 in - http1.cancelRequest(requestID) - }, http2: { http2 in - http2.cancelRequest(requestID) - }) + self.state.modify( + http1: { http1 in + http1.cancelRequest(requestID) + }, + http2: { http2 in + http2.cancelRequest(requestID) + } + ) } mutating func connectionIdleTimeout(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.connectionIdleTimeout(connectionID) - }, http2: { http2 in - http2.connectionIdleTimeout(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.connectionIdleTimeout(connectionID) + }, + http2: { http2 in + http2.connectionIdleTimeout(connectionID) + } + ) } /// A connection has been closed mutating func http1ConnectionClosed(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http1ConnectionClosed(connectionID) - }, http2: { http2 in - http2.http1ConnectionClosed(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http1ConnectionClosed(connectionID) + }, + http2: { http2 in + http2.http1ConnectionClosed(connectionID) + } + ) } mutating func http1ConnectionReleased(_ connectionID: Connection.ID) -> Action { - self.state.modify(http1: { http1 in - http1.http1ConnectionReleased(connectionID) - }, http2: { http2 in - http2.http1ConnectionReleased(connectionID) - }) + self.state.modify( + http1: { http1 in + http1.http1ConnectionReleased(connectionID) + }, + http2: { http2 in + http2.http1ConnectionReleased(connectionID) + } + ) } mutating func shutdown() -> Action { - return self.state.modify(http1: { http1 in - http1.shutdown() - }, http2: { http2 in - http2.shutdown() - }) + self.state.modify( + http1: { http1 in + http1.shutdown() + }, + http2: { http2 in + http2.shutdown() + } + ) } } } @@ -342,7 +399,9 @@ extension HTTPConnectionPool.StateMachine { } struct EstablishedAction { - static let none: Self = .init(request: .none, connection: .none) + static var none: Self { + Self(request: .none, connection: .none) + } let request: HTTPConnectionPool.StateMachine.RequestAction let connection: EstablishedConnectionAction } @@ -350,7 +409,10 @@ extension HTTPConnectionPool.StateMachine { enum EstablishedConnectionAction { case none case scheduleTimeoutTimer(HTTPConnectionPool.Connection.ID, on: EventLoop) - case closeConnection(HTTPConnectionPool.Connection, isShutdown: HTTPConnectionPool.StateMachine.ConnectionAction.IsShutdown) + case closeConnection( + HTTPConnectionPool.Connection, + isShutdown: HTTPConnectionPool.StateMachine.ConnectionAction.IsShutdown + ) } } @@ -391,8 +453,7 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction { case .closeConnection(let connection, let isShutdown): guard isShutdown == .no else { precondition( - migrationAction.closeConnections.isEmpty && - migrationAction.createConnections.isEmpty, + migrationAction.closeConnections.isEmpty && migrationAction.createConnections.isEmpty, "migration actions are not supported during shutdown" ) return .closeConnection(connection, isShutdown: isShutdown) diff --git a/Sources/AsyncHTTPClient/DeconstructedURL.swift b/Sources/AsyncHTTPClient/DeconstructedURL.swift index 020c17455..52042bce3 100644 --- a/Sources/AsyncHTTPClient/DeconstructedURL.swift +++ b/Sources/AsyncHTTPClient/DeconstructedURL.swift @@ -48,9 +48,16 @@ extension DeconstructedURL { switch scheme { case .http, .https: + #if !canImport(Darwin) && compiler(>=6.0) + guard let urlHost = url.host, !urlHost.isEmpty else { + throw HTTPClientError.emptyHost + } + let host = urlHost.trimIPv6Brackets() + #else guard let host = url.host, !host.isEmpty else { throw HTTPClientError.emptyHost } + #endif self.init( scheme: scheme, connectionTarget: .init(remoteHost: host, port: url.port ?? scheme.defaultPort), @@ -81,3 +88,26 @@ extension DeconstructedURL { } } } + +#if !canImport(Darwin) && compiler(>=6.0) +extension String { + @inlinable internal func trimIPv6Brackets() -> String { + var utf8View = self.utf8[...] + + var modified = false + if utf8View.first == UInt8(ascii: "[") { + utf8View = utf8View.dropFirst() + modified = true + } + if utf8View.last == UInt8(ascii: "]") { + utf8View = utf8View.dropLast() + modified = true + } + + if modified { + return String(Substring(utf8View)) + } + return self + } +} +#endif diff --git a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift index 9a351f3c1..33a4d3cb2 100644 --- a/Sources/AsyncHTTPClient/FileDownloadDelegate.swift +++ b/Sources/AsyncHTTPClient/FileDownloadDelegate.swift @@ -12,30 +12,68 @@ // //===----------------------------------------------------------------------===// +import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import NIOPosix +import struct Foundation.URL + /// Handles a streaming download to a given file path, allowing headers and progress to be reported. public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// The response type for this delegate: the total count of bytes as reported by the response - /// "Content-Length" header (if available) and the count of bytes downloaded. + /// "Content-Length" header (if available), the count of bytes downloaded, the + /// response head, and a history of requests and responses. public struct Progress: Sendable { public var totalBytes: Int? public var receivedBytes: Int + + /// The history of all requests and responses in redirect order. + public var history: [HTTPClient.RequestResponse] = [] + + /// The target URL (after redirects) of the response. + public var url: URL? { + self.history.last?.request.url + } + + public var head: HTTPResponseHead { + get { + assert(self._head != nil) + return self._head! + } + set { + self._head = newValue + } + } + + fileprivate var _head: HTTPResponseHead? = nil + + internal init(totalBytes: Int? = nil, receivedBytes: Int) { + self.totalBytes = totalBytes + self.receivedBytes = receivedBytes + } } - private var progress = Progress(totalBytes: nil, receivedBytes: 0) + private struct State { + var progress = Progress( + totalBytes: nil, + receivedBytes: 0 + ) + var fileIOThreadPool: NIOThreadPool? + var fileHandleFuture: EventLoopFuture? + var writeFuture: EventLoopFuture? + } + private let state: NIOLockedValueBox + + var _fileIOThreadPool: NIOThreadPool? { + self.state.withLockedValue { $0.fileIOThreadPool } + } public typealias Response = Progress private let filePath: String - private(set) var fileIOThreadPool: NIOThreadPool? - private let reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? - private let reportProgress: ((HTTPClient.Task, Progress) -> Void)? - - private var fileHandleFuture: EventLoopFuture? - private var writeFuture: EventLoopFuture? + private let reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? + private let reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? /// Initializes a new file download delegate. /// @@ -47,20 +85,14 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public init( path: String, pool: NIOThreadPool? = nil, - reportHead: ((HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, - reportProgress: ((HTTPClient.Task, Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPClient.Task, HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (HTTPClient.Task, Progress) -> Void)? = nil ) throws { - if let pool = pool { - self.fileIOThreadPool = pool - } else { - // we should use the shared thread pool from the HTTPClient which - // we will get from the `HTTPClient.Task` - self.fileIOThreadPool = nil - } - + self.state = NIOLockedValueBox(State(fileIOThreadPool: pool)) self.filePath = path self.reportHead = reportHead @@ -77,22 +109,23 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public convenience init( path: String, pool: NIOThreadPool, - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil ) throws { try self.init( path: path, pool: .some(pool), reportHead: reportHead.map { reportHead in - return { _, head in + { @Sendable _, head in reportHead(head) } }, reportProgress: reportProgress.map { reportProgress in - return { _, head in + { @Sendable _, head in reportProgress(head) } } @@ -108,38 +141,50 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { /// the total byte count and download byte count passed to it as arguments. The callbacks /// will be invoked in the same threading context that the delegate itself is invoked, /// as controlled by `EventLoopPreference`. + @preconcurrency public convenience init( path: String, - reportHead: ((HTTPResponseHead) -> Void)? = nil, - reportProgress: ((Progress) -> Void)? = nil + reportHead: (@Sendable (HTTPResponseHead) -> Void)? = nil, + reportProgress: (@Sendable (Progress) -> Void)? = nil ) throws { try self.init( path: path, pool: nil, reportHead: reportHead.map { reportHead in - return { _, head in + { @Sendable _, head in reportHead(head) } }, reportProgress: reportProgress.map { reportProgress in - return { _, head in + { @Sendable _, head in reportProgress(head) } } ) } + public func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) { + self.state.withLockedValue { + $0.progress.history.append(.init(request: request, responseHead: head)) + } + } + public func didReceiveHead( task: HTTPClient.Task, _ head: HTTPResponseHead ) -> EventLoopFuture { - self.reportHead?(task, head) + self.state.withLockedValue { + $0.progress._head = head - if let totalBytesString = head.headers.first(name: "Content-Length"), - let totalBytes = Int(totalBytesString) { - self.progress.totalBytes = totalBytes + if let totalBytesString = head.headers.first(name: "Content-Length"), + let totalBytes = Int(totalBytesString) + { + $0.progress.totalBytes = totalBytes + } } + self.reportHead?(task, head) + return task.eventLoop.makeSucceededFuture(()) } @@ -147,53 +192,90 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { task: HTTPClient.Task, _ buffer: ByteBuffer ) -> EventLoopFuture { - let threadPool: NIOThreadPool = { - guard let pool = self.fileIOThreadPool else { - let pool = task.fileIOThreadPool - self.fileIOThreadPool = pool + let (progress, io) = self.state.withLockedValue { state in + let threadPool: NIOThreadPool = { + guard let pool = state.fileIOThreadPool else { + let pool = task.fileIOThreadPool + state.fileIOThreadPool = pool + return pool + } return pool + }() + + let io = NonBlockingFileIO(threadPool: threadPool) + state.progress.receivedBytes += buffer.readableBytes + return (state.progress, io) + } + self.reportProgress?(task, progress) + + let writeFuture = self.state.withLockedValue { state in + let writeFuture: EventLoopFuture + if let fileHandleFuture = state.fileHandleFuture { + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } + } else { + let fileHandleFuture = io.openFile( + _deprecatedPath: self.filePath, + mode: .write, + flags: .allowFileCreation(), + eventLoop: task.eventLoop + ) + state.fileHandleFuture = fileHandleFuture + writeFuture = fileHandleFuture.flatMap { + io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) + } } - return pool - }() - let io = NonBlockingFileIO(threadPool: threadPool) - self.progress.receivedBytes += buffer.readableBytes - self.reportProgress?(task, self.progress) - - let writeFuture: EventLoopFuture - if let fileHandleFuture = self.fileHandleFuture { - writeFuture = fileHandleFuture.flatMap { - io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) - } - } else { - let fileHandleFuture = io.openFile( - path: self.filePath, - mode: .write, - flags: .allowFileCreation(), - eventLoop: task.eventLoop - ) - self.fileHandleFuture = fileHandleFuture - writeFuture = fileHandleFuture.flatMap { - io.write(fileHandle: $0, buffer: buffer, eventLoop: task.eventLoop) - } + + state.writeFuture = writeFuture + return writeFuture } - self.writeFuture = writeFuture return writeFuture } private func close(fileHandle: NIOFileHandle) { try! fileHandle.close() - self.fileHandleFuture = nil + self.state.withLockedValue { + $0.fileHandleFuture = nil + } } private func finalize() { - if let writeFuture = self.writeFuture { - writeFuture.whenComplete { _ in - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) - self.writeFuture = nil + enum Finalize { + case writeFuture(EventLoopFuture) + case fileHandleFuture(EventLoopFuture) + case none + } + + let finalize: Finalize = self.state.withLockedValue { state in + if let writeFuture = state.writeFuture { + return .writeFuture(writeFuture) + } else if let fileHandleFuture = state.fileHandleFuture { + return .fileHandleFuture(fileHandleFuture) + } else { + return .none + } + } + + switch finalize { + case .writeFuture(let future): + future.whenComplete { _ in + let fileHandleFuture = self.state.withLockedValue { state in + let future = state.fileHandleFuture + state.fileHandleFuture = nil + state.writeFuture = nil + return future + } + + fileHandleFuture?.whenSuccess { + self.close(fileHandle: $0) + } } - } else { - self.fileHandleFuture?.whenSuccess(self.close(fileHandle:)) + case .fileHandleFuture(let future): + future.whenSuccess { self.close(fileHandle: $0) } + case .none: + () } } @@ -203,6 +285,6 @@ public final class FileDownloadDelegate: HTTPClientResponseDelegate { public func didFinishRequest(task: HTTPClient.Task) throws -> Response { self.finalize() - return self.progress + return self.state.withLockedValue { $0.progress } } } diff --git a/Sources/AsyncHTTPClient/FoundationExtensions.swift b/Sources/AsyncHTTPClient/FoundationExtensions.swift index 545da756b..452cb7b13 100644 --- a/Sources/AsyncHTTPClient/FoundationExtensions.swift +++ b/Sources/AsyncHTTPClient/FoundationExtensions.swift @@ -39,7 +39,16 @@ extension HTTPClient.Cookie { /// - maxAge: The cookie's age in seconds, defaults to nil. /// - httpOnly: Whether this cookie should be used by HTTP servers only, defaults to false. /// - secure: Whether this cookie should only be sent using secure channels, defaults to false. - public init(name: String, value: String, path: String = "/", domain: String? = nil, expires: Date? = nil, maxAge: Int? = nil, httpOnly: Bool = false, secure: Bool = false) { + public init( + name: String, + value: String, + path: String = "/", + domain: String? = nil, + expires: Date? = nil, + maxAge: Int? = nil, + httpOnly: Bool = false, + secure: Bool = false + ) { // FIXME: This should be failable and validate the inputs // (for example, checking that the strings are ASCII, path begins with "/", domain is not empty, etc). self.init( @@ -59,8 +68,8 @@ extension HTTPClient.Body { /// Create and stream body using `Data`. /// /// - parameters: - /// - bytes: Body `Data` representation. + /// - data: Body `Data` representation. public static func data(_ data: Data) -> HTTPClient.Body { - return self.bytes(data) + self.bytes(data) } } diff --git a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift index b2e9d7b05..759f6728a 100644 --- a/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift +++ b/Sources/AsyncHTTPClient/HTTPClient+HTTPCookie.swift @@ -12,16 +12,25 @@ // //===----------------------------------------------------------------------===// +import CAsyncHTTPClient +import NIOCore import NIOHTTP1 + +#if canImport(xlocale) +import xlocale +#elseif canImport(locale_h) +import locale_h +#endif + #if canImport(Darwin) import Darwin #elseif canImport(Musl) import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif -import CAsyncHTTPClient -import NIOCore extension HTTPClient { /// A representation of an HTTP cookie. @@ -48,7 +57,6 @@ extension HTTPClient { /// - parameters: /// - header: String representation of the `Set-Cookie` response header. /// - defaultDomain: Default domain to use if cookie was sent without one. - /// - returns: nil if the header is invalid. public init?(header: String, defaultDomain: String) { // The parsing of "Set-Cookie" headers is defined by Section 5.2, RFC-6265: // https://datatracker.ietf.org/doc/html/rfc6265#section-5.2 @@ -129,7 +137,16 @@ extension HTTPClient { /// - maxAge: The cookie's age in seconds, defaults to nil. /// - httpOnly: Whether this cookie should be used by HTTP servers only, defaults to false. /// - secure: Whether this cookie should only be sent using secure channels, defaults to false. - internal init(name: String, value: String, path: String = "/", domain: String? = nil, expires_timestamp: Int64? = nil, maxAge: Int? = nil, httpOnly: Bool = false, secure: Bool = false) { + internal init( + name: String, + value: String, + path: String = "/", + domain: String? = nil, + expires_timestamp: Int64? = nil, + maxAge: Int? = nil, + httpOnly: Bool = false, + secure: Bool = false + ) { self.name = name self.value = value self.path = path @@ -145,7 +162,7 @@ extension HTTPClient { extension HTTPClient.Response { /// List of HTTP cookies returned by the server. public var cookies: [HTTPClient.Cookie] { - return self.headers["set-cookie"].compactMap { HTTPClient.Cookie(header: $0, defaultDomain: self.host) } + self.headers["set-cookie"].compactMap { HTTPClient.Cookie(header: $0, defaultDomain: self.host) } } } @@ -199,7 +216,7 @@ extension String.UTF8View.SubSequence { } } -private let posixLocale: UnsafeMutableRawPointer = { +nonisolated(unsafe) private let posixLocale: UnsafeMutableRawPointer = { // All POSIX systems must provide a "POSIX" locale, and its date/time formats are US English. // https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/V1_chap07.html#tag_07_03_05 let _posixLocale = newlocale(LC_TIME_MASK | LC_NUMERIC_MASK, "POSIX", nil)! @@ -215,7 +232,8 @@ private func parseTimestamp(_ utf8: String.UTF8View.SubSequence, format: String) } private func parseCookieTime(_ timestampUTF8: String.UTF8View.SubSequence) -> Int64? { - if timestampUTF8.contains(where: { $0 < 0x20 /* Control characters */ || $0 == 0x7F /* DEL */ }) { + // 0x20: Control characters or 0x7F: DEL + if timestampUTF8.contains(where: { $0 < 0x20 || $0 == 0x7F }) { return nil } var timestampUTF8 = timestampUTF8 @@ -228,8 +246,8 @@ private func parseCookieTime(_ timestampUTF8: String.UTF8View.SubSequence) -> In } guard var timeComponents = parseTimestamp(timestampUTF8, format: "%a, %d %b %Y %H:%M:%S") - ?? parseTimestamp(timestampUTF8, format: "%a, %d-%b-%y %H:%M:%S") - ?? parseTimestamp(timestampUTF8, format: "%a %b %d %H:%M:%S %Y") + ?? parseTimestamp(timestampUTF8, format: "%a, %d-%b-%y %H:%M:%S") + ?? parseTimestamp(timestampUTF8, format: "%a %b %d %H:%M:%S %Y") else { return nil } diff --git a/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift b/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift index 25b4b4555..e95c828ce 100644 --- a/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift +++ b/Sources/AsyncHTTPClient/HTTPClient+Proxy.swift @@ -38,7 +38,10 @@ extension HTTPClient.Configuration { /// Specifies Proxy server authorization. public var authorization: HTTPClient.Authorization? { set { - precondition(self.type == .http(self.authorization), "SOCKS authorization support is not yet implemented.") + precondition( + self.type == .http(self.authorization), + "SOCKS authorization support is not yet implemented." + ) self.type = .http(newValue) } @@ -60,7 +63,7 @@ extension HTTPClient.Configuration { /// - host: proxy server host. /// - port: proxy server port. public static func server(host: String, port: Int) -> Proxy { - return .init(host: host, port: port, type: .http(nil)) + .init(host: host, port: port, type: .http(nil)) } /// Create a HTTP proxy. @@ -70,7 +73,7 @@ extension HTTPClient.Configuration { /// - port: proxy server port. /// - authorization: proxy server authorization. public static func server(host: String, port: Int, authorization: HTTPClient.Authorization? = nil) -> Proxy { - return .init(host: host, port: port, type: .http(authorization)) + .init(host: host, port: port, type: .http(authorization)) } /// Create a SOCKSv5 proxy. @@ -78,7 +81,7 @@ extension HTTPClient.Configuration { /// - parameter port: The SOCKSv5 proxy port, defaults to 1080. /// - returns: A new instance of `Proxy` configured to connect to a `SOCKSv5` server. public static func socksServer(host: String, port: Int = 1080) -> Proxy { - return .init(host: host, port: port, type: .socks) + .init(host: host, port: port, type: .socks) } } } diff --git a/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift new file mode 100644 index 000000000..f7d471f10 --- /dev/null +++ b/Sources/AsyncHTTPClient/HTTPClient+StructuredConcurrency.swift @@ -0,0 +1,72 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import Logging +import NIO + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +extension HTTPClient { + #if compiler(>=6.0) + /// Start & automatically shut down a new ``HTTPClient``. + /// + /// This method allows to start & automatically dispose of a ``HTTPClient`` following the principle of Structured Concurrency. + /// The ``HTTPClient`` is guaranteed to be shut down upon return, whether `body` throws or not. + /// + /// This may be particularly useful if you cannot use the shared singleton (``HTTPClient/shared``). + public static func withHTTPClient( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger? = nil, + isolation: isolated (any Actor)? = #isolation, + _ body: (HTTPClient) async throws -> Return + ) async throws -> Return { + let logger = (backgroundActivityLogger ?? HTTPClient.loggingDisabled) + let httpClient = HTTPClient( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: logger + ) + return try await asyncDo { + try await body(httpClient) + } finally: { _ in + try await httpClient.shutdown() + } + } + #else + /// Start & automatically shut down a new ``HTTPClient``. + /// + /// This method allows to start & automatically dispose of a ``HTTPClient`` following the principle of Structured Concurrency. + /// The ``HTTPClient`` is guaranteed to be shut down upon return, whether `body` throws or not. + /// + /// This may be particularly useful if you cannot use the shared singleton (``HTTPClient/shared``). + public static func withHTTPClient( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger? = nil, + _ body: (HTTPClient) async throws -> Return + ) async throws -> Return { + let logger = (backgroundActivityLogger ?? HTTPClient.loggingDisabled) + let httpClient = HTTPClient( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: logger + ) + return try await asyncDo { + try await body(httpClient) + } finally: { _ in + try await httpClient.shutdown() + } + } + #endif +} diff --git a/Sources/AsyncHTTPClient/HTTPClient.swift b/Sources/AsyncHTTPClient/HTTPClient.swift index 6c5a9af20..e628c6073 100644 --- a/Sources/AsyncHTTPClient/HTTPClient.swift +++ b/Sources/AsyncHTTPClient/HTTPClient.swift @@ -26,7 +26,7 @@ import NIOTransportServices extension Logger { private func requestInfo(_ request: HTTPClient.Request) -> Logger.Metadata.Value { - return "\(request.method) \(request.url)" + "\(request.method) \(request.url)" } func attachingRequestInformation(_ request: HTTPClient.Request, requestID: Int) -> Logger { @@ -80,23 +80,31 @@ public class HTTPClient { /// - parameters: /// - eventLoopGroupProvider: Specify how `EventLoopGroup` will be created. /// - configuration: Client configuration. - public convenience init(eventLoopGroupProvider: EventLoopGroupProvider, - configuration: Configuration = Configuration()) { - self.init(eventLoopGroupProvider: eventLoopGroupProvider, - configuration: configuration, - backgroundActivityLogger: HTTPClient.loggingDisabled) + public convenience init( + eventLoopGroupProvider: EventLoopGroupProvider, + configuration: Configuration = Configuration() + ) { + self.init( + eventLoopGroupProvider: eventLoopGroupProvider, + configuration: configuration, + backgroundActivityLogger: HTTPClient.loggingDisabled + ) } /// Create an ``HTTPClient`` with specified `EventLoopGroup` and configuration. /// /// - parameters: - /// - eventLoopGroupProvider: Specify how `EventLoopGroup` will be created. + /// - eventLoopGroup: Specify how `EventLoopGroup` will be created. /// - configuration: Client configuration. - public convenience init(eventLoopGroup: EventLoopGroup = HTTPClient.defaultEventLoopGroup, - configuration: Configuration = Configuration()) { - self.init(eventLoopGroupProvider: .shared(eventLoopGroup), - configuration: configuration, - backgroundActivityLogger: HTTPClient.loggingDisabled) + public convenience init( + eventLoopGroup: EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration() + ) { + self.init( + eventLoopGroupProvider: .shared(eventLoopGroup), + configuration: configuration, + backgroundActivityLogger: HTTPClient.loggingDisabled + ) } /// Create an ``HTTPClient`` with specified `EventLoopGroup` provider and configuration. @@ -104,21 +112,26 @@ public class HTTPClient { /// - parameters: /// - eventLoopGroupProvider: Specify how `EventLoopGroup` will be created. /// - configuration: Client configuration. - public convenience init(eventLoopGroupProvider: EventLoopGroupProvider, - configuration: Configuration = Configuration(), - backgroundActivityLogger: Logger) { + /// - backgroundActivityLogger: The logger to use for background activity logs. + public convenience init( + eventLoopGroupProvider: EventLoopGroupProvider, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger + ) { let eventLoopGroup: any EventLoopGroup switch eventLoopGroupProvider { case .shared(let group): eventLoopGroup = group - default: // handle `.createNew` without a deprecation warning + default: // handle `.createNew` without a deprecation warning eventLoopGroup = HTTPClient.defaultEventLoopGroup } - self.init(eventLoopGroup: eventLoopGroup, - configuration: configuration, - backgroundActivityLogger: backgroundActivityLogger) + self.init( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: backgroundActivityLogger + ) } /// Create an ``HTTPClient`` with specified `EventLoopGroup` and configuration. @@ -127,19 +140,25 @@ public class HTTPClient { /// - eventLoopGroup: The `EventLoopGroup` that the ``HTTPClient`` will use. /// - configuration: Client configuration. /// - backgroundActivityLogger: The `Logger` that will be used to log background any activity that's not associated with a request. - public convenience init(eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, - configuration: Configuration = Configuration(), - backgroundActivityLogger: Logger) { - self.init(eventLoopGroup: eventLoopGroup, - configuration: configuration, - backgroundActivityLogger: backgroundActivityLogger, - canBeShutDown: true) + public convenience init( + eventLoopGroup: any EventLoopGroup = HTTPClient.defaultEventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger + ) { + self.init( + eventLoopGroup: eventLoopGroup, + configuration: configuration, + backgroundActivityLogger: backgroundActivityLogger, + canBeShutDown: true + ) } - internal required init(eventLoopGroup: EventLoopGroup, - configuration: Configuration = Configuration(), - backgroundActivityLogger: Logger, - canBeShutDown: Bool) { + internal required init( + eventLoopGroup: EventLoopGroup, + configuration: Configuration = Configuration(), + backgroundActivityLogger: Logger, + canBeShutDown: Bool + ) { self.canBeShutDown = canBeShutDown self.eventLoopGroup = eventLoopGroup self.configuration = configuration @@ -158,15 +177,19 @@ public class HTTPClient { case .shutDown: break case .shuttingDown: - preconditionFailure(""" - This state should be totally unreachable. While the HTTPClient is shutting down a \ - reference cycle should exist, that prevents it from deinit. - """) + preconditionFailure( + """ + This state should be totally unreachable. While the HTTPClient is shutting down a \ + reference cycle should exist, that prevents it from deinit. + """ + ) case .upAndRunning: - preconditionFailure(""" - Client not shut down before the deinit. Please call client.shutdown() when no \ - longer needed. Otherwise memory will leak. - """) + preconditionFailure( + """ + Client not shut down before the deinit. Please call client.shutdown() when no \ + longer needed. Otherwise memory will leak. + """ + ) } } } @@ -191,29 +214,58 @@ public class HTTPClient { /// In general, setting this parameter to `true` should make it easier and faster to catch related programming errors. func syncShutdown(requiresCleanClose: Bool) throws { if let eventLoop = MultiThreadedEventLoopGroup.currentEventLoop { - preconditionFailure(""" - BUG DETECTED: syncShutdown() must not be called when on an EventLoop. - Calling syncShutdown() on any EventLoop can lead to deadlocks. - Current eventLoop: \(eventLoop) - """) + preconditionFailure( + """ + BUG DETECTED: syncShutdown() must not be called when on an EventLoop. + Calling syncShutdown() on any EventLoop can lead to deadlocks. + Current eventLoop: \(eventLoop) + """ + ) } - let errorStorageLock = NIOLock() - let errorStorage: UnsafeMutableTransferBox = .init(nil) - let continuation = DispatchWorkItem {} - self.shutdown(requiresCleanClose: requiresCleanClose, queue: DispatchQueue(label: "async-http-client.shutdown")) { error in - if let error = error { - errorStorageLock.withLock { - errorStorage.wrappedValue = error + + final class ShutdownError: @unchecked Sendable { + // @unchecked because error is protected by lock. + + // Stores whether the shutdown has happened or not. + private let lock: ConditionLock + private var error: Error? + + init() { + self.error = nil + self.lock = ConditionLock(value: false) + } + + func didShutdown(_ error: (any Error)?) { + self.lock.lock(whenValue: false) + defer { + self.lock.unlock(withValue: true) } + self.error = error } - continuation.perform() - } - continuation.wait() - try errorStorageLock.withLock { - if let error = errorStorage.wrappedValue { - throw error + + func blockUntilShutdown() -> (any Error)? { + self.lock.lock(whenValue: true) + defer { + self.lock.unlock(withValue: true) + } + return self.error } } + + let shutdownError = ShutdownError() + + self.shutdown( + requiresCleanClose: requiresCleanClose, + queue: DispatchQueue(label: "async-http-client.shutdown") + ) { error in + shutdownError.didShutdown(error) + } + + let error = shutdownError.blockUntilShutdown() + + if let error = error { + throw error + } } /// Shuts down the client and event loop gracefully. @@ -286,6 +338,7 @@ public class HTTPClient { } } + @Sendable private func makeOrGetFileIOThreadPool() -> NIOThreadPool { self.fileIOThreadPoolLock.withLock { guard let fileIOThreadPool = self.fileIOThreadPool else { @@ -301,7 +354,7 @@ public class HTTPClient { /// - url: Remote URL. /// - deadline: Point in time by which the request must complete. public func get(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.get(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.get(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `GET` request using specified URL. @@ -311,7 +364,7 @@ public class HTTPClient { /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. public func get(url: String, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.GET, url: url, deadline: deadline, logger: logger) + self.execute(.GET, url: url, deadline: deadline, logger: logger) } /// Execute `POST` request using specified URL. @@ -321,7 +374,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.post(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.post(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `POST` request using specified URL. @@ -331,8 +384,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func post(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.POST, url: url, body: body, deadline: deadline, logger: logger) + public func post( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.POST, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `PATCH` request using specified URL. @@ -342,7 +400,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.patch(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.patch(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `PATCH` request using specified URL. @@ -352,8 +410,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func patch(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.PATCH, url: url, body: body, deadline: deadline, logger: logger) + public func patch( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.PATCH, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `PUT` request using specified URL. @@ -363,7 +426,7 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.put(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.put(url: url, body: body, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `PUT` request using specified URL. @@ -373,8 +436,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func put(url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.PUT, url: url, body: body, deadline: deadline, logger: logger) + public func put( + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger + ) -> EventLoopFuture { + self.execute(.PUT, url: url, body: body, deadline: deadline, logger: logger) } /// Execute `DELETE` request using specified URL. @@ -383,7 +451,7 @@ public class HTTPClient { /// - url: Remote URL. /// - deadline: The time when the request must have been completed by. public func delete(url: String, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.delete(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.delete(url: url, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute `DELETE` request using specified URL. @@ -393,7 +461,7 @@ public class HTTPClient { /// - deadline: The time when the request must have been completed by. /// - logger: The logger to use for this request. public func delete(url: String, deadline: NIODeadline? = nil, logger: Logger) -> EventLoopFuture { - return self.execute(.DELETE, url: url, deadline: deadline, logger: logger) + self.execute(.DELETE, url: url, deadline: deadline, logger: logger) } /// Execute arbitrary HTTP request using specified URL. @@ -404,7 +472,13 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, url: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + url: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { let request = try Request(url: url, method: method, body: body) return self.execute(request: request, deadline: deadline, logger: logger ?? HTTPClient.loggingDisabled) @@ -422,7 +496,14 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, socketPath: String, urlPath: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + socketPath: String, + urlPath: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { guard let url = URL(httpURLWithSocketPath: socketPath, uri: urlPath) else { throw HTTPClientError.invalidURL @@ -443,7 +524,14 @@ public class HTTPClient { /// - body: Request body. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(_ method: HTTPMethod = .GET, secureSocketPath: String, urlPath: String, body: Body? = nil, deadline: NIODeadline? = nil, logger: Logger? = nil) -> EventLoopFuture { + public func execute( + _ method: HTTPMethod = .GET, + secureSocketPath: String, + urlPath: String, + body: Body? = nil, + deadline: NIODeadline? = nil, + logger: Logger? = nil + ) -> EventLoopFuture { do { guard let url = URL(httpsURLWithSocketPath: secureSocketPath, uri: urlPath) else { throw HTTPClientError.invalidURL @@ -461,7 +549,7 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - deadline: Point in time by which the request must complete. public func execute(request: Request, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.execute(request: request, deadline: deadline, logger: HTTPClient.loggingDisabled) + self.execute(request: request, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute arbitrary HTTP request using specified URL. @@ -481,26 +569,40 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - eventLoop: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. - public func execute(request: Request, eventLoop: EventLoopPreference, deadline: NIODeadline? = nil) -> EventLoopFuture { - return self.execute(request: request, - eventLoop: eventLoop, - deadline: deadline, - logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + eventLoop: EventLoopPreference, + deadline: NIODeadline? = nil + ) -> EventLoopFuture { + self.execute( + request: request, + eventLoop: eventLoop, + deadline: deadline, + logger: HTTPClient.loggingDisabled + ) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. /// /// - parameters: /// - request: HTTP request to execute. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(request: Request, - eventLoop eventLoopPreference: EventLoopPreference, - deadline: NIODeadline? = nil, - logger: Logger?) -> EventLoopFuture { + public func execute( + request: Request, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil, + logger: Logger? + ) -> EventLoopFuture { let accumulator = ResponseAccumulator(request: request) - return self.execute(request: request, delegate: accumulator, eventLoop: eventLoopPreference, deadline: deadline, logger: logger).futureResult + return self.execute( + request: request, + delegate: accumulator, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: logger + ).futureResult } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -509,10 +611,12 @@ public class HTTPClient { /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. /// - deadline: Point in time by which the request must complete. - public func execute(request: Request, - delegate: Delegate, - deadline: NIODeadline? = nil) -> Task { - return self.execute(request: request, delegate: delegate, deadline: deadline, logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + delegate: Delegate, + deadline: NIODeadline? = nil + ) -> Task { + self.execute(request: request, delegate: delegate, deadline: deadline, logger: HTTPClient.loggingDisabled) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -522,11 +626,13 @@ public class HTTPClient { /// - delegate: Delegate to process response parts. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. - public func execute(request: Request, - delegate: Delegate, - deadline: NIODeadline? = nil, - logger: Logger) -> Task { - return self.execute(request: request, delegate: delegate, eventLoop: .indifferent, deadline: deadline, logger: logger) + public func execute( + request: Request, + delegate: Delegate, + deadline: NIODeadline? = nil, + logger: Logger + ) -> Task { + self.execute(request: request, delegate: delegate, eventLoop: .indifferent, deadline: deadline, logger: logger) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -534,18 +640,21 @@ public class HTTPClient { /// - parameters: /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. - /// - logger: The logger to use for this request. - public func execute(request: Request, - delegate: Delegate, - eventLoop eventLoopPreference: EventLoopPreference, - deadline: NIODeadline? = nil) -> Task { - return self.execute(request: request, - delegate: delegate, - eventLoop: eventLoopPreference, - deadline: deadline, - logger: HTTPClient.loggingDisabled) + public func execute( + request: Request, + delegate: Delegate, + eventLoop eventLoopPreference: EventLoopPreference, + deadline: NIODeadline? = nil + ) -> Task { + self.execute( + request: request, + delegate: delegate, + eventLoop: eventLoopPreference, + deadline: deadline, + logger: HTTPClient.loggingDisabled + ) } /// Execute arbitrary HTTP request and handle response processing using provided delegate. @@ -553,7 +662,7 @@ public class HTTPClient { /// - parameters: /// - request: HTTP request to execute. /// - delegate: Delegate to process response parts. - /// - eventLoop: NIO Event Loop preference. + /// - eventLoopPreference: NIO Event Loop preference. /// - deadline: Point in time by which the request must complete. /// - logger: The logger to use for this request. public func execute( @@ -561,14 +670,14 @@ public class HTTPClient { delegate: Delegate, eventLoop eventLoopPreference: EventLoopPreference, deadline: NIODeadline? = nil, - logger originalLogger: Logger? + logger: Logger? ) -> Task { self._execute( request: request, delegate: delegate, eventLoop: eventLoopPreference, deadline: deadline, - logger: originalLogger, + logger: logger, redirectState: RedirectState( self.configuration.redirectConfiguration.mode, initialURL: request.url.absoluteString @@ -592,25 +701,38 @@ public class HTTPClient { logger originalLogger: Logger?, redirectState: RedirectState? ) -> Task { - let logger = (originalLogger ?? HTTPClient.loggingDisabled).attachingRequestInformation(request, requestID: globalRequestID.wrappingIncrementThenLoad(ordering: .relaxed)) + let logger = (originalLogger ?? HTTPClient.loggingDisabled).attachingRequestInformation( + request, + requestID: globalRequestID.wrappingIncrementThenLoad(ordering: .relaxed) + ) let taskEL: EventLoop switch eventLoopPreference.preference { case .indifferent: // if possible we want a connection on the current `EventLoop` taskEL = self.eventLoopGroup.any() case .delegate(on: let eventLoop): - precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + precondition( + self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, + "Provided EventLoop must be part of clients EventLoopGroup." + ) taskEL = eventLoop case .delegateAndChannel(on: let eventLoop): - precondition(self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, "Provided EventLoop must be part of clients EventLoopGroup.") + precondition( + self.eventLoopGroup.makeIterator().contains { $0 === eventLoop }, + "Provided EventLoop must be part of clients EventLoopGroup." + ) taskEL = eventLoop case .testOnly_exact(_, delegateOn: let delegateEL): taskEL = delegateEL } - logger.trace("selected EventLoop for task given the preference", - metadata: ["ahc-eventloop": "\(taskEL)", - "ahc-el-preference": "\(eventLoopPreference)"]) + logger.trace( + "selected EventLoop for task given the preference", + metadata: [ + "ahc-eventloop": "\(taskEL)", + "ahc-el-preference": "\(eventLoopPreference)", + ] + ) let failedTask: Task? = self.stateLock.withLock { switch self.state { @@ -618,10 +740,12 @@ public class HTTPClient { return nil case .shuttingDown, .shutDown: logger.debug("client is shutting down, failing request") - return Task.failedTask(eventLoop: taskEL, - error: HTTPClientError.alreadyShutdown, - logger: logger, - makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool) + return Task.failedTask( + eventLoop: taskEL, + error: HTTPClientError.alreadyShutdown, + logger: logger, + makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool + ) } } @@ -644,7 +768,11 @@ public class HTTPClient { } }() - let task = Task(eventLoop: taskEL, logger: logger, makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool) + let task = Task( + eventLoop: taskEL, + logger: logger, + makeOrGetFileIOThreadPool: self.makeOrGetFileIOThreadPool + ) do { let requestBag = try RequestBag( request: request, @@ -656,20 +784,20 @@ public class HTTPClient { delegate: delegate ) - var deadlineSchedule: Scheduled? if let deadline = deadline { - deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { + let deadlineSchedule = taskEL.scheduleTask(deadline: deadline) { requestBag.deadlineExceeded() } task.promise.futureResult.whenComplete { _ in - deadlineSchedule?.cancel() + deadlineSchedule.cancel() } } self.poolManager.executeRequest(requestBag) } catch { - task.fail(with: error, delegateType: Delegate.self) + delegate.didReceiveError(task: task, error) + task.failInternal(with: error) } return task @@ -711,7 +839,12 @@ public class HTTPClient { /// Enables automatic body decompression. Supported algorithms are gzip and deflate. public var decompression: Decompression /// Ignore TLS unclean shutdown error, defaults to `false`. - @available(*, deprecated, message: "AsyncHTTPClient now correctly supports handling unexpected SSL connection drops. This property is ignored") + @available( + *, + deprecated, + message: + "AsyncHTTPClient now correctly supports handling unexpected SSL connection drops. This property is ignored" + ) public var ignoreUncleanSSLShutdown: Bool { get { false } set {} @@ -738,6 +871,19 @@ public class HTTPClient { } } + /// Whether ``HTTPClient`` will use Multipath TCP or not + /// By default, don't use it + public var enableMultipath: Bool + + /// A method with access to the HTTP/1 connection channel that is called when creating the connection. + public var http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 connection channel that is called when creating the connection. + public var http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + + /// A method with access to the HTTP/2 stream channel that is called when creating the stream. + public var http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? + public init( tlsConfiguration: TLSConfiguration? = nil, redirectConfiguration: RedirectConfiguration? = nil, @@ -755,14 +901,17 @@ public class HTTPClient { self.decompression = decompression self.httpVersion = .automatic self.networkFrameworkWaitForConnectivity = true + self.enableMultipath = false } - public init(tlsConfiguration: TLSConfiguration? = nil, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { self.init( tlsConfiguration: tlsConfiguration, redirectConfiguration: redirectConfiguration, @@ -774,49 +923,59 @@ public class HTTPClient { ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - maximumAllowedIdleTimeInConnectionPool: TimeAmount = .seconds(60), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + maximumAllowedIdleTimeInConnectionPool: TimeAmount = .seconds(60), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = certificateVerification - self.init(tlsConfiguration: tlsConfig, - redirectConfiguration: redirectConfiguration, - timeout: timeout, - connectionPool: ConnectionPool(idleTimeout: maximumAllowedIdleTimeInConnectionPool), - proxy: proxy, - ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, - decompression: decompression) + self.init( + tlsConfiguration: tlsConfig, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: ConnectionPool(idleTimeout: maximumAllowedIdleTimeInConnectionPool), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - connectionPool: TimeAmount = .seconds(60), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled, - backgroundActivityLogger: Logger?) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: TimeAmount = .seconds(60), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + backgroundActivityLogger: Logger? + ) { var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = certificateVerification - self.init(tlsConfiguration: tlsConfig, - redirectConfiguration: redirectConfiguration, - timeout: timeout, - connectionPool: ConnectionPool(idleTimeout: connectionPool), - proxy: proxy, - ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, - decompression: decompression) + self.init( + tlsConfiguration: tlsConfig, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: ConnectionPool(idleTimeout: connectionPool), + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) } - public init(certificateVerification: CertificateVerification, - redirectConfiguration: RedirectConfiguration? = nil, - timeout: Timeout = Timeout(), - proxy: Proxy? = nil, - ignoreUncleanSSLShutdown: Bool = false, - decompression: Decompression = .disabled) { + public init( + certificateVerification: CertificateVerification, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled + ) { self.init( certificateVerification: certificateVerification, redirectConfiguration: redirectConfiguration, @@ -827,6 +986,32 @@ public class HTTPClient { decompression: decompression ) } + + public init( + tlsConfiguration: TLSConfiguration? = nil, + redirectConfiguration: RedirectConfiguration? = nil, + timeout: Timeout = Timeout(), + connectionPool: ConnectionPool = ConnectionPool(), + proxy: Proxy? = nil, + ignoreUncleanSSLShutdown: Bool = false, + decompression: Decompression = .disabled, + http1_1ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2ConnectionDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil, + http2StreamChannelDebugInitializer: (@Sendable (Channel) -> EventLoopFuture)? = nil + ) { + self.init( + tlsConfiguration: tlsConfiguration, + redirectConfiguration: redirectConfiguration, + timeout: timeout, + connectionPool: connectionPool, + proxy: proxy, + ignoreUncleanSSLShutdown: ignoreUncleanSSLShutdown, + decompression: decompression + ) + self.http1_1ConnectionDebugInitializer = http1_1ConnectionDebugInitializer + self.http2ConnectionDebugInitializer = http2ConnectionDebugInitializer + self.http2StreamChannelDebugInitializer = http2StreamChannelDebugInitializer + } } /// Specifies how `EventLoopGroup` will be created and establishes lifecycle ownership. @@ -870,7 +1055,7 @@ public class HTTPClient { /// `EventLoop` but will not establish a new network connection just to satisfy the `EventLoop` preference if /// another existing connection on a different `EventLoop` is readily available from a connection pool. public static func delegate(on eventLoop: EventLoop) -> EventLoopPreference { - return EventLoopPreference(.delegate(on: eventLoop)) + EventLoopPreference(.delegate(on: eventLoop)) } /// The delegate and the `Channel` will be run on the specified EventLoop. @@ -878,7 +1063,7 @@ public class HTTPClient { /// Use this for use-cases where you prefer a new connection to be established over re-using an existing /// connection that might be on a different `EventLoop`. public static func delegateAndChannel(on eventLoop: EventLoop) -> EventLoopPreference { - return EventLoopPreference(.delegateAndChannel(on: eventLoop)) + EventLoopPreference(.delegateAndChannel(on: eventLoop)) } } @@ -902,7 +1087,7 @@ public class HTTPClient { extension HTTPClient.EventLoopGroupProvider { /// Shares ``HTTPClient/defaultEventLoopGroup`` which is a singleton `EventLoopGroup` suitable for the platform. public static var singleton: Self { - return .shared(HTTPClient.defaultEventLoopGroup) + .shared(HTTPClient.defaultEventLoopGroup) } } @@ -1005,18 +1190,20 @@ extension HTTPClient.Configuration { /// - allowCycles: Whether cycles are allowed. /// /// - warning: Cycle detection will keep all visited URLs in memory which means a malicious server could use this as a denial-of-service vector. - public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { return .init(configuration: .follow(max: max, allowCycles: allowCycles)) } + public static func follow(max: Int, allowCycles: Bool) -> RedirectConfiguration { + .init(configuration: .follow(max: max, allowCycles: allowCycles)) + } } /// Connection pool configuration. public struct ConnectionPool: Hashable, Sendable { /// Specifies amount of time connections are kept idle in the pool. After this time has passed without a new /// request the connections are closed. - public var idleTimeout: TimeAmount + public var idleTimeout: TimeAmount = .seconds(60) /// The maximum number of connections that are kept alive in the connection pool per host. If requests with /// an explicit eventLoopRequirement are sent, this number might be exceeded due to overflow connections. - public var concurrentHTTP1ConnectionsPerHostSoftLimit: Int + public var concurrentHTTP1ConnectionsPerHostSoftLimit: Int = 8 /// If true, ``HTTPClient`` will try to create new connections on connection failure with an exponential backoff. /// Requests will only fail after the ``HTTPClient/Configuration/Timeout-swift.struct/connect`` timeout exceeded. @@ -1025,16 +1212,17 @@ extension HTTPClient.Configuration { /// - warning: We highly recommend leaving this on. /// It is very common that connections establishment is flaky at scale. /// ``HTTPClient`` will automatically mitigate these kind of issues if this flag is turned on. - var retryConnectionEstablishment: Bool + public var retryConnectionEstablishment: Bool = true + + public init() {} - public init(idleTimeout: TimeAmount = .seconds(60)) { - self.init(idleTimeout: idleTimeout, concurrentHTTP1ConnectionsPerHostSoftLimit: 8) + public init(idleTimeout: TimeAmount) { + self.idleTimeout = idleTimeout } public init(idleTimeout: TimeAmount, concurrentHTTP1ConnectionsPerHostSoftLimit: Int) { self.idleTimeout = idleTimeout self.concurrentHTTP1ConnectionsPerHostSoftLimit = concurrentHTTP1ConnectionsPerHostSoftLimit - self.retryConnectionEstablishment = true } } @@ -1102,7 +1290,7 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { } public var description: String { - return "HTTPClientError.\(String(describing: self.code))" + "HTTPClientError.\(String(describing: self.code))" } /// Short description of the error that can be used in case a bounded set of error descriptions is expected, e.g. to @@ -1192,7 +1380,9 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// URL does not contain scheme. public static let emptyScheme = HTTPClientError(code: .emptyScheme) /// Provided URL scheme is not supported, supported schemes are: `http` and `https` - public static func unsupportedScheme(_ scheme: String) -> HTTPClientError { return HTTPClientError(code: .unsupportedScheme(scheme)) } + public static func unsupportedScheme(_ scheme: String) -> HTTPClientError { + HTTPClientError(code: .unsupportedScheme(scheme)) + } /// Request timed out while waiting for response. public static let readTimeout = HTTPClientError(code: .readTimeout) /// Request timed out. @@ -1221,9 +1411,13 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// A body was sent in a request with method TRACE. public static let traceRequestWithBody = HTTPClientError(code: .traceRequestWithBody) /// Header field names contain invalid characters. - public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldNames(names)) } + public static func invalidHeaderFieldNames(_ names: [String]) -> HTTPClientError { + HTTPClientError(code: .invalidHeaderFieldNames(names)) + } /// Header field values contain invalid characters. - public static func invalidHeaderFieldValues(_ values: [String]) -> HTTPClientError { return HTTPClientError(code: .invalidHeaderFieldValues(values)) } + public static func invalidHeaderFieldValues(_ values: [String]) -> HTTPClientError { + HTTPClientError(code: .invalidHeaderFieldValues(values)) + } /// Body length is not equal to `Content-Length`. public static let bodyLengthMismatch = HTTPClientError(code: .bodyLengthMismatch) /// Body part was written after request was fully sent. @@ -1241,12 +1435,12 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { public static let tlsHandshakeTimeout = HTTPClientError(code: .tlsHandshakeTimeout) /// The remote server only offered an unsupported application protocol public static func serverOfferedUnsupportedApplicationProtocol(_ proto: String) -> HTTPClientError { - return HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) + HTTPClientError(code: .serverOfferedUnsupportedApplicationProtocol(proto)) } /// The globally shared singleton ``HTTPClient`` cannot be shut down. public static var shutdownUnsupported: HTTPClientError { - return HTTPClientError(code: .shutdownUnsupported) + HTTPClientError(code: .shutdownUnsupported) } /// The request deadline was exceeded. The request was cancelled because of this. @@ -1263,6 +1457,11 @@ public struct HTTPClientError: Error, Equatable, CustomStringConvertible { /// - Tasks are not processed fast enough on the existing connections, to process all waiters in time public static let getConnectionFromPoolTimeout = HTTPClientError(code: .getConnectionFromPoolTimeout) - @available(*, deprecated, message: "AsyncHTTPClient now correctly supports informational headers. For this reason `httpEndReceivedAfterHeadWith1xx` will not be thrown anymore.") + @available( + *, + deprecated, + message: + "AsyncHTTPClient now correctly supports informational headers. For this reason `httpEndReceivedAfterHeadWith1xx` will not be thrown anymore." + ) public static let httpEndReceivedAfterHeadWith1xx = HTTPClientError(code: .httpEndReceivedAfterHeadWith1xx) } diff --git a/Sources/AsyncHTTPClient/HTTPHandler.swift b/Sources/AsyncHTTPClient/HTTPHandler.swift index 98415a124..8d92d8ef7 100644 --- a/Sources/AsyncHTTPClient/HTTPHandler.swift +++ b/Sources/AsyncHTTPClient/HTTPHandler.swift @@ -13,7 +13,6 @@ //===----------------------------------------------------------------------===// import Algorithms -import Foundation import Logging import NIOConcurrencyHelpers import NIOCore @@ -21,20 +20,27 @@ import NIOHTTP1 import NIOPosix import NIOSSL +#if compiler(>=6.0) +import Foundation +#else +@preconcurrency import Foundation +#endif + extension HTTPClient { /// A request body. - public struct Body { + public struct Body: Sendable { /// A streaming uploader. /// /// ``StreamWriter`` abstracts - public struct StreamWriter { - let closure: (IOData) -> EventLoopFuture + public struct StreamWriter: Sendable { + let closure: @Sendable (IOData) -> EventLoopFuture /// Create new ``HTTPClient/Body/StreamWriter`` /// /// - parameters: /// - closure: function that will be called to write actual bytes to the channel. - public init(closure: @escaping (IOData) -> EventLoopFuture) { + @preconcurrency + public init(closure: @escaping @Sendable (IOData) -> EventLoopFuture) { self.closure = closure } @@ -43,34 +49,92 @@ extension HTTPClient { /// - parameters: /// - data: `IOData` to write. public func write(_ data: IOData) -> EventLoopFuture { - return self.closure(data) + self.closure(data) } @inlinable - func writeChunks(of bytes: Bytes, maxChunkSize: Int) -> EventLoopFuture where Bytes.Element == UInt8 { - let iterator = UnsafeMutableTransferBox(bytes.chunks(ofCount: maxChunkSize).makeIterator()) - guard let chunk = iterator.wrappedValue.next() else { - return self.write(IOData.byteBuffer(.init())) - } + func writeChunks( + of bytes: Bytes, + maxChunkSize: Int + ) -> EventLoopFuture where Bytes.Element == UInt8, Bytes: Sendable { + // `StreamWriter` has design issues, for example + // - https://github.com/swift-server/async-http-client/issues/194 + // - https://github.com/swift-server/async-http-client/issues/264 + // - We're not told the EventLoop the task runs on and the user is free to return whatever EL they + // want. + // One important consideration then is that we must lock around the iterator because we could be hopping + // between threads. + typealias Iterator = EnumeratedSequence>.Iterator + typealias Chunk = (offset: Int, element: ChunksOfCountCollection.Element) + + // HACK (again, we're not told the right EventLoop): Let's write 0 bytes to make the user tell us... + return self.write(.byteBuffer(ByteBuffer())).flatMapWithEventLoop { (_, loop) in + func makeIteratorAndFirstChunk( + bytes: Bytes + ) -> (iterator: Iterator, chunk: Chunk)? { + var iterator = bytes.chunks(ofCount: maxChunkSize).enumerated().makeIterator() + guard let chunk = iterator.next() else { + return nil + } + + return (iterator, chunk) + } + + guard let iteratorAndChunk = makeIteratorAndFirstChunk(bytes: bytes) else { + return loop.makeSucceededVoidFuture() + } - @Sendable // can't use closure here as we recursively call ourselves which closures can't do - func writeNextChunk(_ chunk: Bytes.SubSequence) -> EventLoopFuture { - if let nextChunk = iterator.wrappedValue.next() { - return self.write(.byteBuffer(ByteBuffer(bytes: chunk))).flatMap { - writeNextChunk(nextChunk) + var iterator = iteratorAndChunk.0 + let chunk = iteratorAndChunk.1 + + // can't use closure here as we recursively call ourselves which closures can't do + func writeNextChunk(_ chunk: Chunk, allDone: EventLoopPromise) { + let loop = allDone.futureResult.eventLoop + loop.assertInEventLoop() + + if let (index, element) = iterator.next() { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).hop(to: loop).assumeIsolated().map + { + if (index + 1) % 4 == 0 { + // Let's not stack-overflow if the futures insta-complete which they at least in HTTP/2 + // mode. + // Also, we must frequently return to the EventLoop because we may get the pause signal + // from another thread. If we fail to do that promptly, we may balloon our body chunks + // into memory. + allDone.futureResult.eventLoop.assumeIsolated().execute { + writeNextChunk((offset: index, element: element), allDone: allDone) + } + } else { + writeNextChunk((offset: index, element: element), allDone: allDone) + } + }.nonisolated().cascadeFailure(to: allDone) + } else { + self.write(.byteBuffer(ByteBuffer(bytes: chunk.element))).cascade(to: allDone) } - } else { - return self.write(.byteBuffer(ByteBuffer(bytes: chunk))) } - } - return writeNextChunk(chunk) + let allDone = loop.makePromise(of: Void.self) + writeNextChunk(chunk, allDone: allDone) + return allDone.futureResult + } } } /// Body size. If nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` /// header is set with the given `length`. - public var length: Int? + @available(*, deprecated, renamed: "contentLength") + public var length: Int? { + get { + self.contentLength.flatMap { Int($0) } + } + set { + self.contentLength = newValue.flatMap { Int64($0) } + } + } + + /// Body size. If nil,`Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// header is set with the given `contentLength`. + public var contentLength: Int64? /// Body chunk provider. public var stream: @Sendable (StreamWriter) -> EventLoopFuture @@ -78,8 +142,8 @@ extension HTTPClient { @usableFromInline typealias StreamCallback = @Sendable (StreamWriter) -> EventLoopFuture @inlinable - init(length: Int?, stream: @escaping StreamCallback) { - self.length = length + init(contentLength: Int64?, stream: @escaping StreamCallback) { + self.contentLength = contentLength.flatMap { $0 } self.stream = stream } @@ -88,7 +152,7 @@ extension HTTPClient { /// - parameters: /// - buffer: Body `ByteBuffer` representation. public static func byteBuffer(_ buffer: ByteBuffer) -> Body { - return Body(length: buffer.readableBytes) { writer in + Body(contentLength: Int64(buffer.readableBytes)) { writer in writer.write(.byteBuffer(buffer)) } } @@ -99,19 +163,37 @@ extension HTTPClient { /// - length: Body size. If nil, `Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` /// header is set with the given `length`. /// - stream: Body chunk provider. + @_disfavoredOverload @preconcurrency - public static func stream(length: Int? = nil, _ stream: @Sendable @escaping (StreamWriter) -> EventLoopFuture) -> Body { - return Body(length: length, stream: stream) + public static func stream( + length: Int? = nil, + _ stream: @Sendable @escaping (StreamWriter) -> EventLoopFuture + ) -> Body { + Body(contentLength: length.flatMap { Int64($0) }, stream: stream) + } + + /// Create and stream body using ``StreamWriter``. + /// + /// - parameters: + /// - contentLength: Body size. If nil, `Transfer-Encoding` will automatically be set to `chunked`. Otherwise a `Content-Length` + /// header is set with the given `contentLength`. + /// - stream: Body chunk provider. + public static func stream( + contentLength: Int64? = nil, + _ stream: @Sendable @escaping (StreamWriter) -> EventLoopFuture + ) -> Body { + Body(contentLength: contentLength, stream: stream) } /// Create and stream body using a collection of bytes. /// /// - parameters: - /// - data: Body binary representation. + /// - bytes: Body binary representation. @preconcurrency @inlinable - public static func bytes(_ bytes: Bytes) -> Body where Bytes: RandomAccessCollection, Bytes: Sendable, Bytes.Element == UInt8 { - return Body(length: bytes.count) { writer in + public static func bytes(_ bytes: Bytes) -> Body + where Bytes: RandomAccessCollection, Bytes: Sendable, Bytes.Element == UInt8 { + Body(contentLength: Int64(bytes.count)) { writer in if bytes.count <= bagOfBytesToByteBufferConversionChunkSize { return writer.write(.byteBuffer(ByteBuffer(bytes: bytes))) } else { @@ -125,7 +207,7 @@ extension HTTPClient { /// - parameters: /// - string: Body `String` representation. public static func string(_ string: String) -> Body { - return Body(length: string.utf8.count) { writer in + Body(contentLength: Int64(string.utf8.count)) { writer in if string.utf8.count <= bagOfBytesToByteBufferConversionChunkSize { return writer.write(.byteBuffer(ByteBuffer(string: string))) } else { @@ -136,7 +218,7 @@ extension HTTPClient { } /// Represents an HTTP request. - public struct Request { + public struct Request: Sendable { /// Request HTTP method, defaults to `GET`. public let method: HTTPMethod /// Remote URL. @@ -161,7 +243,6 @@ extension HTTPClient { /// /// - parameters: /// - url: Remote `URL`. - /// - version: HTTP version. /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. @@ -170,7 +251,12 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init( + url: String, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil + ) throws { try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) } @@ -178,7 +264,6 @@ extension HTTPClient { /// /// - parameters: /// - url: Remote `URL`. - /// - version: HTTP version. /// - method: HTTP method. /// - headers: Custom HTTP headers. /// - body: Request body. @@ -188,7 +273,13 @@ extension HTTPClient { /// - `emptyScheme` if URL does not contain HTTP scheme. /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. - public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { + public init( + url: String, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil, + tlsConfiguration: TLSConfiguration? + ) throws { guard let url = URL(string: url) else { throw HTTPClientError.invalidURL } @@ -208,7 +299,8 @@ extension HTTPClient { /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. - public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws { + public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws + { try self.init(url: url, method: method, headers: headers, body: body, tlsConfiguration: nil) } @@ -225,7 +317,13 @@ extension HTTPClient { /// - `unsupportedScheme` if URL does contains unsupported HTTP scheme. /// - `emptyHost` if URL does not contains a host. /// - `missingSocketPath` if URL does not contains a socketPath as an encoded host. - public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil, tlsConfiguration: TLSConfiguration?) throws { + public init( + url: URL, + method: HTTPMethod = .GET, + headers: HTTPHeaders = HTTPHeaders(), + body: Body? = nil, + tlsConfiguration: TLSConfiguration? + ) throws { self.deconstructedURL = try DeconstructedURL(url: url) self.url = url @@ -258,14 +356,26 @@ extension HTTPClient { head.headers.addHostIfNeeded(for: self.deconstructedURL) - let metadata = try head.headers.validateAndSetTransportFraming(method: self.method, bodyLength: .init(self.body)) + let metadata = try head.headers.validateAndSetTransportFraming( + method: self.method, + bodyLength: .init(self.body) + ) return (head, metadata) } + + /// Set basic auth for a request. + /// + /// - parameters: + /// - username: the username to authenticate with + /// - password: authentication password associated with the username + public mutating func setBasicAuth(username: String, password: String) { + self.headers.setBasicAuth(username: username, password: password) + } } /// Represents an HTTP response. - public struct Response { + public struct Response: Sendable { /// Remote host of the request. public var host: String /// Response HTTP status. @@ -276,6 +386,13 @@ extension HTTPClient { public var headers: HTTPHeaders /// Response body. public var body: ByteBuffer? + /// The history of all requests and responses in redirect order. + public var history: [RequestResponse] + + /// The target URL (after redirects) of the response. + public var url: URL? { + self.history.last?.request.url + } /// Create HTTP `Response`. /// @@ -291,6 +408,30 @@ extension HTTPClient { self.version = HTTPVersion(major: 1, minor: 1) self.headers = headers self.body = body + self.history = [] + } + + /// Create HTTP `Response`. + /// + /// - parameters: + /// - host: Remote host of the request. + /// - status: Response HTTP status. + /// - version: Response HTTP version. + /// - headers: Reponse HTTP headers. + /// - body: Response body. + public init( + host: String, + status: HTTPResponseStatus, + version: HTTPVersion, + headers: HTTPHeaders, + body: ByteBuffer? + ) { + self.host = host + self.status = status + self.version = version + self.headers = headers + self.body = body + self.history = [] } /// Create HTTP `Response`. @@ -301,12 +442,21 @@ extension HTTPClient { /// - version: Response HTTP version. /// - headers: Reponse HTTP headers. /// - body: Response body. - public init(host: String, status: HTTPResponseStatus, version: HTTPVersion, headers: HTTPHeaders, body: ByteBuffer?) { + /// - history: History of all requests and responses in redirect order. + public init( + host: String, + status: HTTPResponseStatus, + version: HTTPVersion, + headers: HTTPHeaders, + body: ByteBuffer?, + history: [RequestResponse] + ) { self.host = host self.status = status self.version = version self.headers = headers self.body = body + self.history = history } } @@ -325,19 +475,19 @@ extension HTTPClient { /// HTTP basic auth. public static func basic(username: String, password: String) -> HTTPClient.Authorization { - return .basic(credentials: Base64.encode(bytes: "\(username):\(password)".utf8)) + .basic(credentials: Base64.encode(bytes: "\(username):\(password)".utf8)) } /// HTTP basic auth. /// /// This version uses the raw string directly. public static func basic(credentials: String) -> HTTPClient.Authorization { - return .init(scheme: .Basic(credentials)) + .init(scheme: .Basic(credentials)) } /// HTTP bearer auth public static func bearer(tokens: String) -> HTTPClient.Authorization { - return .init(scheme: .Bearer(tokens)) + .init(scheme: .Bearer(tokens)) } /// The header string for this auth field. @@ -350,6 +500,16 @@ extension HTTPClient { } } } + + public struct RequestResponse: Sendable { + public var request: Request + public var responseHead: HTTPResponseHead + + public init(request: Request, responseHead: HTTPResponseHead) { + self.request = request + self.responseHead = responseHead + } + } } /// The default ``HTTPClientResponseDelegate``. @@ -374,11 +534,16 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { } public var description: String { - return "ResponseTooBigError: received response body exceeds maximum accepted size of \(self.maxBodySize) bytes" + "ResponseTooBigError: received response body exceeds maximum accepted size of \(self.maxBodySize) bytes" } } - var state = State.idle + private struct MutableState: Sendable { + var history = [HTTPClient.RequestResponse]() + var state = State.idle + } + + private let state: NIOLockedValueBox let requestMethod: HTTPMethod let requestHost: String @@ -412,84 +577,126 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { self.requestMethod = request.method self.requestHost = request.host self.maxBodySize = maxBodySize + self.state = NIOLockedValueBox(MutableState()) + } + + public func didVisitURL( + task: HTTPClient.Task, + _ request: HTTPClient.Request, + _ head: HTTPResponseHead + ) { + self.state.withLockedValue { + $0.history.append(.init(request: request, responseHead: head)) + } } public func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - switch self.state { - case .idle: - if self.requestMethod != .HEAD, - let contentLength = head.headers.first(name: "Content-Length"), - let announcedBodySize = Int(contentLength), - announcedBodySize > self.maxBodySize { - let error = ResponseTooBigError(maxBodySize: maxBodySize) - self.state = .error(error) - return task.eventLoop.makeFailedFuture(error) - } + let responseTooBig: Bool + + if self.requestMethod != .HEAD, + let contentLength = head.headers.first(name: "Content-Length"), + let announcedBodySize = Int(contentLength), + announcedBodySize > self.maxBodySize + { + responseTooBig = true + } else { + responseTooBig = false + } - self.state = .head(head) - case .head: - preconditionFailure("head already set") - case .body: - preconditionFailure("no head received before body") - case .end: - preconditionFailure("request already processed") - case .error: - break - } - return task.eventLoop.makeSucceededFuture(()) + return self.state.withLockedValue { + switch $0.state { + case .idle: + if responseTooBig { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + + $0.state = .head(head) + case .head: + preconditionFailure("head already set") + case .body: + preconditionFailure("no head received before body") + case .end: + preconditionFailure("request already processed") + case .error: + break + } + return task.eventLoop.makeSucceededFuture(()) + } } public func didReceiveBodyPart(task: HTTPClient.Task, _ part: ByteBuffer) -> EventLoopFuture { - switch self.state { - case .idle: - preconditionFailure("no head received before body") - case .head(let head): - guard part.readableBytes <= self.maxBodySize else { - let error = ResponseTooBigError(maxBodySize: self.maxBodySize) - self.state = .error(error) - return task.eventLoop.makeFailedFuture(error) - } - self.state = .body(head, part) - case .body(let head, var body): - let newBufferSize = body.writerIndex + part.readableBytes - guard newBufferSize <= self.maxBodySize else { - let error = ResponseTooBigError(maxBodySize: self.maxBodySize) - self.state = .error(error) - return task.eventLoop.makeFailedFuture(error) - } + self.state.withLockedValue { + switch $0.state { + case .idle: + preconditionFailure("no head received before body") + case .head(let head): + guard part.readableBytes <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } + $0.state = .body(head, part) + case .body(let head, var body): + let newBufferSize = body.writerIndex + part.readableBytes + guard newBufferSize <= self.maxBodySize else { + let error = ResponseTooBigError(maxBodySize: self.maxBodySize) + $0.state = .error(error) + return task.eventLoop.makeFailedFuture(error) + } - // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's - // a cross-module call in the way) so we need to drop the original reference to `body` in - // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which - // has no associated data). We'll fix it at the bottom of this block. - self.state = .end - var part = part - body.writeBuffer(&part) - self.state = .body(head, body) - case .end: - preconditionFailure("request already processed") - case .error: - break - } - return task.eventLoop.makeSucceededFuture(()) + // The compiler can't prove that `self.state` is dead here (and it kinda isn't, there's + // a cross-module call in the way) so we need to drop the original reference to `body` in + // `self.state` or we'll get a CoW. To fix that we temporarily set the state to `.end` (which + // has no associated data). We'll fix it at the bottom of this block. + $0.state = .end + var part = part + body.writeBuffer(&part) + $0.state = .body(head, body) + case .end: + preconditionFailure("request already processed") + case .error: + break + } + return task.eventLoop.makeSucceededFuture(()) + } } public func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.state = .error(error) + self.state.withLockedValue { + $0.state = .error(error) + } } public func didFinishRequest(task: HTTPClient.Task) throws -> Response { - switch self.state { - case .idle: - preconditionFailure("no head received before end") - case .head(let head): - return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: nil) - case .body(let head, let body): - return Response(host: self.requestHost, status: head.status, version: head.version, headers: head.headers, body: body) - case .end: - preconditionFailure("request already processed") - case .error(let error): - throw error + try self.state.withLockedValue { + switch $0.state { + case .idle: + preconditionFailure("no head received before end") + case .head(let head): + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: nil, + history: $0.history + ) + case .body(let head, let body): + return Response( + host: self.requestHost, + status: head.status, + version: head.version, + headers: head.headers, + body: body, + history: $0.history + ) + case .end: + preconditionFailure("request already processed") + case .error(let error): + throw error + } } } } @@ -525,8 +732,9 @@ public final class ResponseAccumulator: HTTPClientResponseDelegate { /// released together with the `HTTPTaskHandler` when channel is closed. /// Users of the library are not required to keep a reference to the /// object that implements this protocol, but may do so if needed. -public protocol HTTPClientResponseDelegate: AnyObject { - associatedtype Response +@preconcurrency +public protocol HTTPClientResponseDelegate: AnyObject, Sendable { + associatedtype Response: Sendable /// Called when the request head is sent. Will be called once. /// @@ -548,7 +756,16 @@ public protocol HTTPClientResponseDelegate: AnyObject { /// - task: Current request context. func didSendRequest(task: HTTPClient.Task) - /// Called when response head is received. Will be called once. + /// Called each time a response head is received (including redirects), and always called before ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``. + /// You can use this method to keep an entire history of the request/response chain. + /// + /// - parameters: + /// - task: Current request context. + /// - request: The request that was sent. + /// - head: Received response head. + func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) + + /// Called when the final response head is received (after redirects). /// You must return an `EventLoopFuture` that you complete when you have finished processing the body part. /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`. /// @@ -614,18 +831,23 @@ extension HTTPClientResponseDelegate { /// By default, this does nothing. public func didSendRequest(task: HTTPClient.Task) {} + /// Default implementation of ``HTTPClientResponseDelegate/didVisitURL(task:_:_:)-2el9y``. + /// + /// By default, this does nothing. + public func didVisitURL(task: HTTPClient.Task, _: HTTPClient.Request, _: HTTPResponseHead) {} + /// Default implementation of ``HTTPClientResponseDelegate/didReceiveHead(task:_:)-9r4xd``. /// /// By default, this does nothing. public func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { - return task.eventLoop.makeSucceededVoidFuture() + task.eventLoop.makeSucceededVoidFuture() } /// Default implementation of ``HTTPClientResponseDelegate/didReceiveBodyPart(task:_:)-4fd4v``. /// /// By default, this does nothing. public func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { - return task.eventLoop.makeSucceededVoidFuture() + task.eventLoop.makeSucceededVoidFuture() } /// Default implementation of ``HTTPClientResponseDelegate/didReceiveError(task:_:)-fhsg``. @@ -639,7 +861,7 @@ extension URL { if self.path.isEmpty { return "/" } - return URLComponents(url: self, resolvingAgainstBaseURL: false)?.percentEncodedPath ?? self.path + return URLComponents(url: self, resolvingAgainstBaseURL: true)?.percentEncodedPath ?? self.path } var uri: String { @@ -653,7 +875,7 @@ extension URL { } func hasTheSameOrigin(as other: URL) -> Bool { - return self.host == other.host && self.scheme == other.scheme && self.port == other.port + self.host == other.host && self.scheme == other.scheme && self.port == other.port } /// Initializes a newly created HTTP URL connecting to a unix domain socket path. The socket path is encoded as the URL's host, replacing percent encoding invalid path characters, and will use the "http+unix" scheme. @@ -687,7 +909,7 @@ extension URL { } } -protocol HTTPClientTaskDelegate { +protocol HTTPClientTaskDelegate: Sendable { func fail(_ error: Error) } @@ -696,58 +918,67 @@ extension HTTPClient { /// /// Will be created by the library and could be used for obtaining /// `EventLoopFuture` of the execution or cancellation of the execution. - public final class Task { + public final class Task: Sendable { /// The `EventLoop` the delegate will be executed on. public let eventLoop: EventLoop /// The `Logger` used by the `Task` for logging. - public let logger: Logger // We are okay to store the logger here because a Task is for only one request. + public let logger: Logger // We are okay to store the logger here because a Task is for only one request. let promise: EventLoopPromise + struct State: Sendable { + var isCancelled: Bool + var taskDelegate: HTTPClientTaskDelegate? + } + + private let state: NIOLockedValueBox + var isCancelled: Bool { - self.lock.withLock { self._isCancelled } + self.state.withLockedValue { $0.isCancelled } } var taskDelegate: HTTPClientTaskDelegate? { get { - self.lock.withLock { self._taskDelegate } + self.state.withLockedValue { $0.taskDelegate } } set { - self.lock.withLock { self._taskDelegate = newValue } + self.state.withLockedValue { $0.taskDelegate = newValue } } } - private var _isCancelled: Bool = false - private var _taskDelegate: HTTPClientTaskDelegate? - private let lock = NIOLock() - private let makeOrGetFileIOThreadPool: () -> NIOThreadPool + private let makeOrGetFileIOThreadPool: @Sendable () -> NIOThreadPool /// The shared thread pool of a ``HTTPClient`` used for file IO. It is lazily created on first access. internal var fileIOThreadPool: NIOThreadPool { self.makeOrGetFileIOThreadPool() } - init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool) { + init(eventLoop: EventLoop, logger: Logger, makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool) { self.eventLoop = eventLoop self.promise = eventLoop.makePromise() self.logger = logger self.makeOrGetFileIOThreadPool = makeOrGetFileIOThreadPool + self.state = NIOLockedValueBox(State(isCancelled: false, taskDelegate: nil)) } static func failedTask( eventLoop: EventLoop, error: Error, logger: Logger, - makeOrGetFileIOThreadPool: @escaping () -> NIOThreadPool + makeOrGetFileIOThreadPool: @escaping @Sendable () -> NIOThreadPool ) -> Task { - let task = self.init(eventLoop: eventLoop, logger: logger, makeOrGetFileIOThreadPool: makeOrGetFileIOThreadPool) + let task = self.init( + eventLoop: eventLoop, + logger: logger, + makeOrGetFileIOThreadPool: makeOrGetFileIOThreadPool + ) task.promise.fail(error) return task } /// `EventLoopFuture` for the response returned by this request. public var futureResult: EventLoopFuture { - return self.promise.futureResult + self.promise.futureResult } /// Waits for execution of this request to complete. @@ -755,56 +986,58 @@ extension HTTPClient { /// - returns: The value of ``futureResult`` when it completes. /// - throws: The error value of ``futureResult`` if it errors. @available(*, noasync, message: "wait() can block indefinitely, prefer get()", renamed: "get()") - public func wait() throws -> Response { - return try self.promise.futureResult.wait() + @preconcurrency + public func wait() throws -> Response where Response: Sendable { + try self.promise.futureResult.wait() } /// Provides the result of this request. /// + /// - warning: This method may violates Structured Concurrency because doesn't respect cancellation. + /// /// - returns: The value of ``futureResult`` when it completes. /// - throws: The error value of ``futureResult`` if it errors. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) - public func get() async throws -> Response { - return try await self.promise.futureResult.get() + @preconcurrency + public func get() async throws -> Response where Response: Sendable { + try await self.promise.futureResult.get() } - /// Cancels the request execution. + /// Initiate cancellation of a HTTP request. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. public func cancel() { self.fail(reason: HTTPClientError.cancelled) } - /// Cancels the request execution with a custom `Error`. - /// - Parameter reason: the error that is used to fail the promise + /// Initiate cancellation of a HTTP request with an `error`. + /// + /// This method will return immeidately and doesn't wait for the cancellation to complete. + /// + /// - Parameter error: the error that is used to fail the promise public func fail(reason error: Error) { - let taskDelegate = self.lock.withLock { () -> HTTPClientTaskDelegate? in - self._isCancelled = true - return self._taskDelegate + let taskDelegate = self.state.withLockedValue { state in + state.isCancelled = true + return state.taskDelegate } taskDelegate?.fail(error) } - func succeed(promise: EventLoopPromise?, - with value: Response, - delegateType: Delegate.Type, - closing: Bool) { - promise?.succeed(value) - } - - func fail(with error: Error, - delegateType: Delegate.Type) { + /// Called internally only, used to fail a task from within the state machine functionality. + func failInternal( + with error: Error + ) { self.promise.fail(error) } } } -extension HTTPClient.Task: @unchecked Sendable {} - internal struct TaskCancelEvent {} // MARK: - RedirectHandler -internal struct RedirectHandler { +internal struct RedirectHandler { let request: HTTPClient.Request let redirectState: RedirectState let execute: (HTTPClient.Request, RedirectState) -> HTTPClient.Task @@ -821,7 +1054,7 @@ internal struct RedirectHandler { status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise - ) { + ) -> HTTPClient.Task? { do { var redirectState = self.redirectState try redirectState.redirect(to: redirectURL.absoluteString) @@ -841,13 +1074,19 @@ internal struct RedirectHandler { headers: headers, body: body ) - self.execute(newRequest, redirectState).futureResult.whenComplete { result in + + let newTask = self.execute(newRequest, redirectState) + + newTask.futureResult.whenComplete { result in promise.futureResult.eventLoop.execute { promise.completeWith(result) } } + + return newTask } catch { promise.fail(error) + return nil } } } @@ -858,7 +1097,7 @@ extension RequestBodyLength { self = .known(0) return } - guard let length = body.length else { + guard let length = body.contentLength else { self = .unknown return } diff --git a/Sources/AsyncHTTPClient/LRUCache.swift b/Sources/AsyncHTTPClient/LRUCache.swift index 0a01da0d2..f8b58c36a 100644 --- a/Sources/AsyncHTTPClient/LRUCache.swift +++ b/Sources/AsyncHTTPClient/LRUCache.swift @@ -52,9 +52,11 @@ struct LRUCache { @discardableResult mutating func append(key: Key, value: Value) -> Value { - let newElement = Element(generation: self.generation, - key: key, - value: value) + let newElement = Element( + generation: self.generation, + key: key, + value: value + ) if let found = self.bumpGenerationAndFindIndex(key: key) { self.elements[found] = newElement return value diff --git a/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift b/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift new file mode 100644 index 000000000..b25a0f00d --- /dev/null +++ b/Sources/AsyncHTTPClient/NIOLoopBound+Execute.swift @@ -0,0 +1,28 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +import NIOCore + +extension NIOLoopBound { + @inlinable + func execute(_ body: @Sendable @escaping (Value) -> Void) { + if self.eventLoop.inEventLoop { + body(self.value) + } else { + self.eventLoop.execute { + body(self.value) + } + } + } +} diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift index 9796bc2af..148b4a4c4 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/NWErrorHandler.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -#if canImport(Network) -import Network -#endif import NIOCore import NIOHTTP1 import NIOTransportServices +#if canImport(Network) +import Network +#endif + extension HTTPClient { #if canImport(Network) /// A wrapper for `POSIX` errors thrown by `Network.framework`. @@ -38,7 +39,7 @@ extension HTTPClient { self.reason = reason } - public var description: String { return self.reason } + public var description: String { self.reason } } /// A wrapper for TLS errors thrown by `Network.framework`. @@ -58,7 +59,7 @@ extension HTTPClient { self.reason = reason } - public var description: String { return self.reason } + public var description: String { self.reason } } #endif diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift b/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift index 3474a8821..d7c6055ec 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/NWWaitingHandler.swift @@ -33,7 +33,10 @@ final class NWWaitingHandler: ChannelInbound func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { if let waitingEvent = event as? NIOTSNetworkEvents.WaitingForConnectivity { - self.requester.waitingForConnectivity(self.connectionID, error: HTTPClient.NWErrorHandler.translateError(waitingEvent.transientError)) + self.requester.waitingForConnectivity( + self.connectionID, + error: HTTPClient.NWErrorHandler.translateError(waitingEvent.transientError) + ) } context.fireUserInboundEventTriggered(event) } diff --git a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift index cb6bd43bd..e8278e095 100644 --- a/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift +++ b/Sources/AsyncHTTPClient/NIOTransportServices/TLSConfiguration.swift @@ -60,13 +60,16 @@ extension TLSVersion { @available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) extension TLSConfiguration { /// Dispatch queue used by Network framework TLS to control certificate verification - static var tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") + static let tlsDispatchQueue = DispatchQueue(label: "TLSDispatch") /// create NWProtocolTLS.Options for use with NIOTransportServices from the NIOSSL TLSConfiguration /// /// - Parameter eventLoop: EventLoop to wait for creation of options on /// - Returns: Future holding NWProtocolTLS Options - func getNWProtocolTLSOptions(on eventLoop: EventLoop, serverNameIndicatorOverride: String?) -> EventLoopFuture { + func getNWProtocolTLSOptions( + on eventLoop: EventLoop, + serverNameIndicatorOverride: String? + ) -> EventLoopFuture { let promise = eventLoop.makePromise(of: NWProtocolTLS.Options.self) Self.tlsDispatchQueue.async { do { @@ -86,11 +89,11 @@ extension TLSConfiguration { let options = NWProtocolTLS.Options() let useMTELGExplainer = """ - You can still use this configuration option on macOS if you initialize HTTPClient \ - with a MultiThreadedEventLoopGroup. Please note that using MultiThreadedEventLoopGroup \ - will make AsyncHTTPClient use NIO on BSD Sockets and not Network.framework (which is the preferred \ - platform networking stack). - """ + You can still use this configuration option on macOS if you initialize HTTPClient \ + with a MultiThreadedEventLoopGroup. Please note that using MultiThreadedEventLoopGroup \ + will make AsyncHTTPClient use NIO on BSD Sockets and not Network.framework (which is the preferred \ + platform networking stack). + """ if let serverNameIndicatorOverride = serverNameIndicatorOverride { serverNameIndicatorOverride.withCString { serverNameIndicatorOverride in @@ -100,15 +103,24 @@ extension TLSConfiguration { // minimum TLS protocol if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { - sec_protocol_options_set_min_tls_protocol_version(options.securityProtocolOptions, self.minimumTLSVersion.nwTLSProtocolVersion) + sec_protocol_options_set_min_tls_protocol_version( + options.securityProtocolOptions, + self.minimumTLSVersion.nwTLSProtocolVersion + ) } else { - sec_protocol_options_set_tls_min_version(options.securityProtocolOptions, self.minimumTLSVersion.sslProtocol) + sec_protocol_options_set_tls_min_version( + options.securityProtocolOptions, + self.minimumTLSVersion.sslProtocol + ) } // maximum TLS protocol if let maximumTLSVersion = self.maximumTLSVersion { if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 6.0, *) { - sec_protocol_options_set_max_tls_protocol_version(options.securityProtocolOptions, maximumTLSVersion.nwTLSProtocolVersion) + sec_protocol_options_set_max_tls_protocol_version( + options.securityProtocolOptions, + maximumTLSVersion.nwTLSProtocolVersion + ) } else { sec_protocol_options_set_tls_max_version(options.securityProtocolOptions, maximumTLSVersion.sslProtocol) } @@ -161,8 +173,10 @@ extension TLSConfiguration { break } - precondition(self.certificateVerification != .noHostnameVerification, - "TLSConfiguration.certificateVerification = .noHostnameVerification is not supported. \(useMTELGExplainer)") + precondition( + self.certificateVerification != .noHostnameVerification, + "TLSConfiguration.certificateVerification = .noHostnameVerification is not supported. \(useMTELGExplainer)" + ) if certificateVerification != .fullVerification || trustRoots != nil { // add verify block to control certificate verification @@ -196,7 +210,8 @@ extension TLSConfiguration { } } } - }, Self.tlsDispatchQueue + }, + Self.tlsDispatchQueue ) } return options diff --git a/Sources/AsyncHTTPClient/RedirectState.swift b/Sources/AsyncHTTPClient/RedirectState.swift index c4e427ef1..95de2d508 100644 --- a/Sources/AsyncHTTPClient/RedirectState.swift +++ b/Sources/AsyncHTTPClient/RedirectState.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import NIOHTTP1 +import struct Foundation.URL + typealias RedirectMode = HTTPClient.Configuration.RedirectConfiguration.Mode struct RedirectState { diff --git a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift index e7fad6850..37b2a42f0 100644 --- a/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift +++ b/Sources/AsyncHTTPClient/RequestBag+StateMachine.swift @@ -12,10 +12,11 @@ // //===----------------------------------------------------------------------===// -import struct Foundation.URL import NIOCore import NIOHTTP1 +import struct Foundation.URL + extension HTTPClient { /// The maximum body size allowed, before a redirect response is cancelled. 3KB. /// @@ -302,10 +303,12 @@ extension RequestBag.StateMachine { preconditionFailure("If we receive a response, we must not have received something else before") } - if let redirectHandler = redirectHandler, let redirectURL = redirectHandler.redirectTarget( - status: head.status, - responseHeaders: head.headers - ) { + if let redirectHandler = redirectHandler, + let redirectURL = redirectHandler.redirectTarget( + status: head.status, + responseHeaders: head.headers + ) + { // If we will redirect, we need to consume the response's body ASAP, to be able to // reuse the existing connection. We will consume a response body, if the body is // smaller than 3kb. @@ -348,7 +351,9 @@ extension RequestBag.StateMachine { case .executing(let executor, let requestState, .buffering(var currentBuffer, next: let next)): guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + preconditionFailure( + "If we have received an error or eof before, why did we get another body part? Next: \(next)" + ) } self.state = .modifying @@ -405,7 +410,9 @@ extension RequestBag.StateMachine { case .executing(let executor, let requestState, .buffering(var buffer, next: let next)): guard case .askExecutorForMore = next else { - preconditionFailure("If we have received an error or eof before, why did we get another body part? Next: \(next)") + preconditionFailure( + "If we have received an error or eof before, why did we get another body part? Next: \(next)" + ) } if buffer.isEmpty, let newChunks = newChunks, !newChunks.isEmpty { @@ -463,7 +470,9 @@ extension RequestBag.StateMachine { case .initialized, .queued, .deadlineExceededWhileQueued: preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): - preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + preconditionFailure( + "Invalid state: Must have received response head, before this method is called for the first time" + ) case .executing(_, _, .buffering(_, next: .error(let connectionError))): // if an error was received from the connection, we fail the task with the one @@ -476,17 +485,23 @@ extension RequestBag.StateMachine { return .failTask(error, executorToCancel: executor) case .executing(_, _, .waitingForRemote): - preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + preconditionFailure( + "Invalid state... We just returned from a consumption function. We can't already be waiting" + ) case .redirected: - preconditionFailure("Invalid state... Redirect don't call out to delegate functions. Thus we should never land here.") + preconditionFailure( + "Invalid state... Redirect don't call out to delegate functions. Thus we should never land here." + ) case .finished(error: .some): // don't overwrite existing errors return .doNothing case .finished(error: .none): - preconditionFailure("Invalid state... If no error occured, this must not be called, after the request was finished") + preconditionFailure( + "Invalid state... If no error occured, this must not be called, after the request was finished" + ) case .modifying: preconditionFailure() @@ -499,7 +514,9 @@ extension RequestBag.StateMachine { preconditionFailure("Invalid state: \(self.state)") case .executing(_, _, .initialized): - preconditionFailure("Invalid state: Must have received response head, before this method is called for the first time") + preconditionFailure( + "Invalid state: Must have received response head, before this method is called for the first time" + ) case .executing(let executor, let requestState, .buffering(var buffer, next: .askExecutorForMore)): self.state = .modifying @@ -529,7 +546,9 @@ extension RequestBag.StateMachine { return .failTask(error, executorToCancel: nil) case .executing(_, _, .waitingForRemote): - preconditionFailure("Invalid state... We just returned from a consumption function. We can't already be waiting") + preconditionFailure( + "Invalid state... We just returned from a consumption function. We can't already be waiting" + ) case .redirected: return .doNothing @@ -538,7 +557,9 @@ extension RequestBag.StateMachine { return .doNothing case .finished(error: .none): - preconditionFailure("Invalid state... If no error occurred, this must not be called, after the request was finished") + preconditionFailure( + "Invalid state... If no error occurred, this must not be called, after the request was finished" + ) case .modifying: preconditionFailure() @@ -559,11 +580,11 @@ extension RequestBag.StateMachine { return .cancelScheduler(queuer) case .initialized, - .deadlineExceededWhileQueued, - .executing, - .finished, - .redirected, - .modifying: + .deadlineExceededWhileQueued, + .executing, + .finished, + .redirected, + .modifying: /// if we are not in the queued state, we can fail early by just calling down to `self.fail(_:)` /// which does the appropriate state transition for us. return .fail(self.fail(HTTPClientError.deadlineExceeded)) diff --git a/Sources/AsyncHTTPClient/RequestBag.swift b/Sources/AsyncHTTPClient/RequestBag.swift index c5472fc6f..f206325ee 100644 --- a/Sources/AsyncHTTPClient/RequestBag.swift +++ b/Sources/AsyncHTTPClient/RequestBag.swift @@ -18,7 +18,8 @@ import NIOCore import NIOHTTP1 import NIOSSL -final class RequestBag { +@preconcurrency +final class RequestBag: Sendable { /// Defends against the call stack getting too large when consuming body parts. /// /// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users @@ -35,13 +36,23 @@ final class RequestBag { } private let delegate: Delegate - private var request: HTTPClient.Request - // the request state is synchronized on the task eventLoop - private var state: StateMachine + struct LoopBoundState: @unchecked Sendable { + // The 'StateMachine' *isn't* Sendable (it holds various objects which aren't). This type + // needs to be sendable so that we can construct a loop bound box off of the event loop + // to hold this state and then subsequently only access it from the event loop. This needs + // to happen so that the request bag can be constructed off of the event loop. If it's + // constructed on the event loop then there's a timing window between users issuing + // a request and calling shutdown where the underlying pool doesn't know about the request + // so the shutdown call may cancel it. + var request: HTTPClient.Request + var state: StateMachine + var consumeBodyPartStackDepth: Int + // if a redirect occurs, we store the task for it so we can propagate cancellation + var redirectTask: HTTPClient.Task? = nil + } - // the consume body part stack depth is synchronized on the task event loop. - private var consumeBodyPartStackDepth: Int + private let loopBoundState: NIOLoopBoundBox // MARK: HTTPClientTask properties @@ -58,19 +69,27 @@ final class RequestBag { let eventLoopPreference: HTTPClient.EventLoopPreference - init(request: HTTPClient.Request, - eventLoopPreference: HTTPClient.EventLoopPreference, - task: HTTPClient.Task, - redirectHandler: RedirectHandler?, - connectionDeadline: NIODeadline, - requestOptions: RequestOptions, - delegate: Delegate) throws { + let tlsConfiguration: TLSConfiguration? + + init( + request: HTTPClient.Request, + eventLoopPreference: HTTPClient.EventLoopPreference, + task: HTTPClient.Task, + redirectHandler: RedirectHandler?, + connectionDeadline: NIODeadline, + requestOptions: RequestOptions, + delegate: Delegate + ) throws { self.poolKey = .init(request, dnsOverride: requestOptions.dnsOverride) self.eventLoopPreference = eventLoopPreference self.task = task - self.state = .init(redirectHandler: redirectHandler) - self.consumeBodyPartStackDepth = 0 - self.request = request + + let loopBoundState = LoopBoundState( + request: request, + state: StateMachine(redirectHandler: redirectHandler), + consumeBodyPartStackDepth: 0 + ) + self.loopBoundState = NIOLoopBoundBox.makeBoxSendingValue(loopBoundState, eventLoop: task.eventLoop) self.connectionDeadline = connectionDeadline self.requestOptions = requestOptions self.delegate = delegate @@ -79,6 +98,8 @@ final class RequestBag { self.requestHead = head self.requestFramingMetadata = metadata + self.tlsConfiguration = request.tlsConfiguration + self.task.taskDelegate = self self.task.futureResult.whenComplete { _ in self.task.taskDelegate = nil @@ -87,22 +108,19 @@ final class RequestBag { private func requestWasQueued0(_ scheduler: HTTPRequestScheduler) { self.logger.debug("Request was queued (waiting for a connection to become available)") - - self.task.eventLoop.assertInEventLoop() - self.state.requestWasQueued(scheduler) + self.loopBoundState.value.state.requestWasQueued(scheduler) } // MARK: - Request - private func willExecuteRequest0(_ executor: HTTPRequestExecutor) { - self.task.eventLoop.assertInEventLoop() - let action = self.state.willExecuteRequest(executor) + let action = self.loopBoundState.value.state.willExecuteRequest(executor) switch action { case .cancelExecuter(let executor): executor.cancelRequest(self) case .failTaskAndCancelExecutor(let error, let executor): self.delegate.didReceiveError(task: self.task, error) - self.task.fail(with: error, delegateType: Delegate.self) + self.task.failInternal(with: error) executor.cancelRequest(self) case .none: break @@ -110,26 +128,22 @@ final class RequestBag { } private func requestHeadSent0() { - self.task.eventLoop.assertInEventLoop() - self.delegate.didSendRequestHead(task: self.task, self.requestHead) - if self.request.body == nil { + if self.loopBoundState.value.request.body == nil { self.delegate.didSendRequest(task: self.task) } } private func resumeRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - let produceAction = self.state.resumeRequestBodyStream() + let produceAction = self.loopBoundState.value.state.resumeRequestBodyStream() switch produceAction { case .startWriter: - guard let body = self.request.body else { + guard let body = self.loopBoundState.value.request.body else { preconditionFailure("Expected to have a body, if the `HTTPRequestStateMachine` resume a request stream") } - self.request.body = nil + self.loopBoundState.value.request.body = nil let writer = HTTPClient.Body.StreamWriter { self.writeNextRequestPart($0) @@ -148,9 +162,7 @@ final class RequestBag { } private func pauseRequestBodyStream0() { - self.task.eventLoop.assertInEventLoop() - - self.state.pauseRequestBodyStream() + self.loopBoundState.value.state.pauseRequestBodyStream() } private func writeNextRequestPart(_ part: IOData) -> EventLoopFuture { @@ -164,14 +176,12 @@ final class RequestBag { } private func writeNextRequestPart0(_ part: IOData) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() - - let action = self.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) + let action = self.loopBoundState.value.state.writeNextRequestPart(part, taskEventLoop: self.task.eventLoop) switch action { case .failTask(let error): self.delegate.didReceiveError(task: self.task, error) - self.task.fail(with: error, delegateType: Delegate.self) + self.task.failInternal(with: error) return self.task.eventLoop.makeFailedFuture(error) case .failFuture(let error): @@ -188,9 +198,7 @@ final class RequestBag { } private func finishRequestBodyStream(_ result: Result) { - self.task.eventLoop.assertInEventLoop() - - let action = self.state.finishRequestBodyStream(result) + let action = self.loopBoundState.value.state.finishRequestBodyStream(result) switch action { case .none: @@ -221,10 +229,10 @@ final class RequestBag { // MARK: - Response - private func receiveResponseHead0(_ head: HTTPResponseHead) { - self.task.eventLoop.assertInEventLoop() + self.delegate.didVisitURL(task: self.task, self.loopBoundState.value.request, head) // runs most likely on channel eventLoop - switch self.state.receiveResponseHead(head) { + switch self.loopBoundState.value.state.receiveResponseHead(head) { case .none: break @@ -232,7 +240,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponseHead(let head): @@ -246,9 +258,7 @@ final class RequestBag { } private func receiveResponseBodyParts0(_ buffer: CircularBuffer) { - self.task.eventLoop.assertInEventLoop() - - switch self.state.receiveResponseBodyParts(buffer) { + switch self.loopBoundState.value.state.receiveResponseBodyParts(buffer) { case .none: break @@ -256,7 +266,11 @@ final class RequestBag { executor.demandResponseBodyStream(self) case .redirect(let executor, let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) executor.cancelRequest(self) case .forwardResponsePart(let part): @@ -270,8 +284,7 @@ final class RequestBag { } private func succeedRequest0(_ buffer: CircularBuffer?) { - self.task.eventLoop.assertInEventLoop() - let action = self.state.succeedRequest(buffer) + let action = self.loopBoundState.value.state.succeedRequest(buffer) switch action { case .none: @@ -292,13 +305,15 @@ final class RequestBag { } case .redirect(let handler, let head, let newURL): - handler.redirect(status: head.status, to: newURL, promise: self.task.promise) + self.loopBoundState.value.redirectTask = handler.redirect( + status: head.status, + to: newURL, + promise: self.task.promise + ) } } private func consumeMoreBodyData0(resultOfPreviousConsume result: Result) { - self.task.eventLoop.assertInEventLoop() - // We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart` // future to be returned to us completed. If it is, we will recurse back into this method. To // break that recursion we have a max stack depth which we increment and decrement in this method: @@ -309,24 +324,27 @@ final class RequestBag { // that risk ending up in this loop. That's because we don't need an accurate count: our limit is // a best-effort target anyway, one stack frame here or there does not put us at risk. We're just // trying to prevent ourselves looping out of control. - self.consumeBodyPartStackDepth += 1 + self.loopBoundState.value.consumeBodyPartStackDepth += 1 defer { - self.consumeBodyPartStackDepth -= 1 - assert(self.consumeBodyPartStackDepth >= 0) + self.loopBoundState.value.consumeBodyPartStackDepth -= 1 + assert(self.loopBoundState.value.consumeBodyPartStackDepth >= 0) } - let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result) + let consumptionAction = self.loopBoundState.value.state.consumeMoreBodyData( + resultOfPreviousConsume: result + ) switch consumptionAction { case .consume(let byteBuffer): self.delegate.didReceiveBodyPart(task: self.task, byteBuffer) .hop(to: self.task.eventLoop) + .assumeIsolated() .whenComplete { result in - if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { + if self.loopBoundState.value.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth { self.consumeMoreBodyData0(resultOfPreviousConsume: result) } else { // We need to unwind the stack, let's take a break. - self.task.eventLoop.execute { + self.task.eventLoop.assumeIsolated().execute { self.consumeMoreBodyData0(resultOfPreviousConsume: result) } } @@ -337,7 +355,7 @@ final class RequestBag { case .finishStream: do { let response = try self.delegate.didFinishRequest(task: self.task) - self.task.promise.succeed(response) + self.task.promise.assumeIsolated().succeed(response) } catch { self.task.promise.fail(error) } @@ -351,11 +369,11 @@ final class RequestBag { } private func fail0(_ error: Error) { - self.task.eventLoop.assertInEventLoop() - - let action = self.state.fail(error) + let action = self.loopBoundState.value.state.fail(error) self.executeFailAction0(action) + + self.loopBoundState.value.redirectTask?.fail(reason: error) } private func executeFailAction0(_ action: RequestBag.StateMachine.FailAction) { @@ -372,8 +390,7 @@ final class RequestBag { } func deadlineExceeded0() { - self.task.eventLoop.assertInEventLoop() - let action = self.state.deadlineExceeded() + let action = self.loopBoundState.value.state.deadlineExceeded() switch action { case .cancelScheduler(let scheduler): @@ -395,9 +412,6 @@ final class RequestBag { } extension RequestBag: HTTPSchedulableRequest, HTTPClientTaskDelegate { - var tlsConfiguration: TLSConfiguration? { - self.request.tlsConfiguration - } func requestWasQueued(_ scheduler: HTTPRequestScheduler) { if self.task.eventLoop.inEventLoop { @@ -435,8 +449,8 @@ extension RequestBag: HTTPExecutableRequest { case .indifferent: return self.task.eventLoop case .delegate(let eventLoop), - .delegateAndChannel(on: let eventLoop), - .testOnly_exact(channelOn: let eventLoop, delegateOn: _): + .delegateAndChannel(on: let eventLoop), + .testOnly_exact(channelOn: let eventLoop, delegateOn: _): return eventLoop } } diff --git a/Sources/AsyncHTTPClient/RequestValidation.swift b/Sources/AsyncHTTPClient/RequestValidation.swift index 87224a3b2..f338e06a9 100644 --- a/Sources/AsyncHTTPClient/RequestValidation.swift +++ b/Sources/AsyncHTTPClient/RequestValidation.swift @@ -50,23 +50,23 @@ extension HTTPHeaders { let satisfy = name.utf8.allSatisfy { char -> Bool in switch char { case UInt8(ascii: "a")...UInt8(ascii: "z"), - UInt8(ascii: "A")...UInt8(ascii: "Z"), - UInt8(ascii: "0")...UInt8(ascii: "9"), - UInt8(ascii: "!"), - UInt8(ascii: "#"), - UInt8(ascii: "$"), - UInt8(ascii: "%"), - UInt8(ascii: "&"), - UInt8(ascii: "'"), - UInt8(ascii: "*"), - UInt8(ascii: "+"), - UInt8(ascii: "-"), - UInt8(ascii: "."), - UInt8(ascii: "^"), - UInt8(ascii: "_"), - UInt8(ascii: "`"), - UInt8(ascii: "|"), - UInt8(ascii: "~"): + UInt8(ascii: "A")...UInt8(ascii: "Z"), + UInt8(ascii: "0")...UInt8(ascii: "9"), + UInt8(ascii: "!"), + UInt8(ascii: "#"), + UInt8(ascii: "$"), + UInt8(ascii: "%"), + UInt8(ascii: "&"), + UInt8(ascii: "'"), + UInt8(ascii: "*"), + UInt8(ascii: "+"), + UInt8(ascii: "-"), + UInt8(ascii: "."), + UInt8(ascii: "^"), + UInt8(ascii: "_"), + UInt8(ascii: "`"), + UInt8(ascii: "|"), + UInt8(ascii: "~"): return true default: return false @@ -166,13 +166,14 @@ extension HTTPHeaders { mutating func addHostIfNeeded(for url: DeconstructedURL) { // if no host header was set, let's use the url host guard !self.contains(name: "host"), - var host = url.connectionTarget.host + var host = url.connectionTarget.host else { return } // if the request uses a non-default port, we need to add it after the host if let port = url.connectionTarget.port, - port != url.scheme.defaultPort { + port != url.scheme.defaultPort + { host += ":\(port)" } self.add(name: "host", value: host) diff --git a/Sources/AsyncHTTPClient/SSLContextCache.swift b/Sources/AsyncHTTPClient/SSLContextCache.swift index 660a04942..599003e56 100644 --- a/Sources/AsyncHTTPClient/SSLContextCache.swift +++ b/Sources/AsyncHTTPClient/SSLContextCache.swift @@ -25,30 +25,38 @@ final class SSLContextCache { } extension SSLContextCache { - func sslContext(tlsConfiguration: TLSConfiguration, - eventLoop: EventLoop, - logger: Logger) -> EventLoopFuture { + func sslContext( + tlsConfiguration: TLSConfiguration, + eventLoop: EventLoop, + logger: Logger + ) -> EventLoopFuture { let eqTLSConfiguration = BestEffortHashableTLSConfiguration(wrapping: tlsConfiguration) let sslContext = self.lock.withLock { self.sslContextCache.find(key: eqTLSConfiguration) } if let sslContext = sslContext { - logger.trace("found SSL context in cache", - metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + logger.trace( + "found SSL context in cache", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"] + ) return eventLoop.makeSucceededFuture(sslContext) } - logger.trace("creating new SSL context", - metadata: ["ahc-tls-config": "\(tlsConfiguration)"]) + logger.trace( + "creating new SSL context", + metadata: ["ahc-tls-config": "\(tlsConfiguration)"] + ) let newSSLContext = self.offloadQueue.asyncWithFuture(eventLoop: eventLoop) { try NIOSSLContext(configuration: tlsConfiguration) } newSSLContext.whenSuccess { (newSSLContext: NIOSSLContext) -> Void in self.lock.withLock { () -> Void in - self.sslContextCache.append(key: eqTLSConfiguration, - value: newSSLContext) + self.sslContextCache.append( + key: eqTLSConfiguration, + value: newSSLContext + ) } } diff --git a/Sources/AsyncHTTPClient/Singleton.swift b/Sources/AsyncHTTPClient/Singleton.swift index 149f7586f..0ddf1bc40 100644 --- a/Sources/AsyncHTTPClient/Singleton.swift +++ b/Sources/AsyncHTTPClient/Singleton.swift @@ -20,7 +20,7 @@ extension HTTPClient { /// - `EventLoopGroup` is ``HTTPClient/defaultEventLoopGroup`` (matching the platform default) /// - logging is disabled public static var shared: HTTPClient { - return globallySharedHTTPClient + globallySharedHTTPClient } } diff --git a/Sources/AsyncHTTPClient/StringConvertibleInstances.swift b/Sources/AsyncHTTPClient/StringConvertibleInstances.swift index f75fb0d87..61d4b067a 100644 --- a/Sources/AsyncHTTPClient/StringConvertibleInstances.swift +++ b/Sources/AsyncHTTPClient/StringConvertibleInstances.swift @@ -14,6 +14,6 @@ extension HTTPClient.EventLoopPreference: CustomStringConvertible { public var description: String { - return "\(self.preference)" + "\(self.preference)" } } diff --git a/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift new file mode 100644 index 000000000..25f1225e0 --- /dev/null +++ b/Sources/AsyncHTTPClient/StructuredConcurrencyHelpers.swift @@ -0,0 +1,83 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the AsyncHTTPClient open source project +// +// Copyright (c) 2025 Apple Inc. and the AsyncHTTPClient project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// +// swift-format-ignore +// Note: Whitespace changes are used to workaround compiler bug +// https://github.com/swiftlang/swift/issues/79285 + +#if compiler(>=6.0) +@inlinable +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func asyncDo( + isolation: isolated (any Actor)? = #isolation, + // DO NOT FIX THE WHITESPACE IN THE NEXT LINE UNTIL 5.10 IS UNSUPPORTED + // https://github.com/swiftlang/swift/issues/79285 + _ body: () async throws -> sending R, finally: sending @escaping ((any Error)?) async throws -> Void) async throws -> sending R { + let result: R + do { + result = try await body() + } catch { + // `body` failed, we need to invoke `finally` with the `error`. + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(error) + }.value + throw error + } + + // `body` succeeded, we need to invoke `finally` with `nil` (no error). + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(nil) + }.value + return result +} +#else +@inlinable +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func asyncDo( + _ body: () async throws -> R, + finally: @escaping @Sendable ((any Error)?) async throws -> Void +) async throws -> R { + let result: R + do { + result = try await body() + } catch { + // `body` failed, we need to invoke `finally` with the `error`. + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(error) + }.value + throw error + } + + // `body` succeeded, we need to invoke `finally` with `nil` (no error). + + // This _looks_ unstructured but isn't really because we unconditionally always await the return. + // We need to have an uncancelled task here to assure this is actually running in case we hit a + // cancellation error. + try await Task { + try await finally(nil) + }.value + return result +} +#endif diff --git a/Sources/AsyncHTTPClient/UnsafeTransfer.swift b/Sources/AsyncHTTPClient/UnsafeTransfer.swift deleted file mode 100644 index ea5af56da..000000000 --- a/Sources/AsyncHTTPClient/UnsafeTransfer.swift +++ /dev/null @@ -1,29 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) 2022 Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -/// ``UnsafeMutableTransferBox`` can be used to make non-`Sendable` values `Sendable` and mutable. -/// It can be used to capture local mutable values in a `@Sendable` closure and mutate them from within the closure. -/// As the name implies, the usage of this is unsafe because it disables the sendable checking of the compiler and does not add any synchronisation. -@usableFromInline -final class UnsafeMutableTransferBox { - @usableFromInline - var wrappedValue: Wrapped - - @inlinable - init(_ wrappedValue: Wrapped) { - self.wrappedValue = wrappedValue - } -} - -extension UnsafeMutableTransferBox: @unchecked Sendable {} diff --git a/Sources/AsyncHTTPClient/Utils.swift b/Sources/AsyncHTTPClient/Utils.swift index f8618ea17..985755143 100644 --- a/Sources/AsyncHTTPClient/Utils.swift +++ b/Sources/AsyncHTTPClient/Utils.swift @@ -18,10 +18,10 @@ import NIOCore /// /// ``HTTPClientCopyingDelegate`` discards most parts of a HTTP response, but streams the body /// to the `chunkHandler` provided on ``init(chunkHandler:)``. This is mostly useful for testing. -public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { +public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate, Sendable { public typealias Response = Void - let chunkHandler: (ByteBuffer) -> EventLoopFuture + let chunkHandler: @Sendable (ByteBuffer) -> EventLoopFuture @preconcurrency public init(chunkHandler: @Sendable @escaping (ByteBuffer) -> EventLoopFuture) { @@ -29,11 +29,11 @@ public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { } public func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - return self.chunkHandler(buffer) + self.chunkHandler(buffer) } public func didFinishRequest(task: HTTPClient.Task) throws { - return () + () } } @@ -44,7 +44,12 @@ public final class HTTPClientCopyingDelegate: HTTPClientResponseDelegate { /// https://forums.swift.org/t/support-debug-only-code/11037 for a discussion. @inlinable internal func debugOnly(_ body: () -> Void) { - assert({ body(); return true }()) + assert( + { + body() + return true + }() + ) } extension BidirectionalCollection where Element: Equatable { @@ -61,8 +66,8 @@ extension BidirectionalCollection where Element: Equatable { guard self[ourIdx] == suffix[suffixIdx] else { return false } } guard suffixIdx == suffix.startIndex else { - return false // Exhausted self, but 'suffix' has elements remaining. + return false // Exhausted self, but 'suffix' has elements remaining. } - return true // Exhausted 'other' without finding a mismatch. + return true // Exhausted 'other' without finding a mismatch. } } diff --git a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c index 5dfdc08a5..6342da89f 100644 --- a/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c +++ b/Sources/CAsyncHTTPClient/CAsyncHTTPClient.c @@ -31,7 +31,7 @@ bool swiftahc_cshims_strptime(const char * string, const char * format, struct t bool swiftahc_cshims_strptime_l(const char * string, const char * format, struct tm * result, void * locale) { // The pointer cast is fine as long we make sure it really points to a locale_t. -#ifdef __musl__ +#if defined(__musl__) || defined(__ANDROID__) const char * firstNonProcessed = strptime(string, format, result); #else const char * firstNonProcessed = strptime_l(string, format, result, (locale_t)locale); diff --git a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift index a30a8cf91..56a08b852 100644 --- a/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift +++ b/Tests/AsyncHTTPClientTests/AsyncAwaitEndToEndTests.swift @@ -12,13 +12,16 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore +import NIOFoundationCompat +import NIOHTTP1 import NIOPosix import NIOSSL import XCTest +@testable import AsyncHTTPClient + private func makeDefaultHTTPClient( eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton ) -> HTTPClient { @@ -65,12 +68,16 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response.url?.absoluteString, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } @@ -85,12 +92,16 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } + XCTAssertEqual(response.url?.absoluteString, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) } @@ -107,13 +118,17 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(ByteBuffer(string: "1234")) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } } @@ -129,13 +144,17 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(AnySendableSequence("1234".utf8), length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } } @@ -151,13 +170,17 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(AnySendableCollection("1234".utf8), length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } } @@ -173,17 +196,81 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .bytes(ByteBuffer(string: "1234").readableBytesView) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], ["4"]) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } } + struct AsyncSequenceByteBufferGenerator: AsyncSequence, Sendable, AsyncIteratorProtocol { + typealias Element = ByteBuffer + + let chunkSize: Int + let totalChunks: Int + let buffer: ByteBuffer + var chunksGenerated: Int = 0 + + init(chunkSize: Int, totalChunks: Int) { + self.chunkSize = chunkSize + self.totalChunks = totalChunks + self.buffer = ByteBuffer(repeating: 1, count: self.chunkSize) + } + + mutating func next() async throws -> ByteBuffer? { + guard self.chunksGenerated < self.totalChunks else { return nil } + + self.chunksGenerated += 1 + return self.buffer + } + + func makeAsyncIterator() -> AsyncSequenceByteBufferGenerator { + self + } + } + + func testEchoStreamThatHas3GBInTotal() async throws { + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + let bin = HTTPBin(.http1_1()) { _ in HTTPEchoHandler() } + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let client: HTTPClient = makeDefaultHTTPClient(eventLoopGroupProvider: .shared(eventLoopGroup)) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + + var request = HTTPClientRequest(url: "http://localhost:\(bin.port)/") + request.method = .POST + + let sequence = AsyncSequenceByteBufferGenerator( + chunkSize: 4_194_304, // 4MB chunk + totalChunks: 768 // Total = 3GB + ) + request.body = .stream(sequence, length: .unknown) + + let response: HTTPClientResponse = try await client.execute( + request, + deadline: .now() + .seconds(30), + logger: logger + ) + XCTAssertEqual(response.headers["content-length"], []) + + var receivedBytes: Int64 = 0 + for try await part in response.body { + receivedBytes += Int64(part.readableBytes) + } + XCTAssertEqual(receivedBytes, 3_221_225_472) // 3GB + } + func testPostWithAsyncSequenceOfByteBuffers() { XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) { _ in HTTPEchoHandler() } @@ -193,19 +280,26 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .stream([ - ByteBuffer(string: "1"), - ByteBuffer(string: "2"), - ByteBuffer(string: "34"), - ].async, length: .unknown) + request.body = .stream( + [ + ByteBuffer(string: "1"), + ByteBuffer(string: "2"), + ByteBuffer(string: "34"), + ].async, + length: .unknown + ) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } } @@ -221,13 +315,17 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.method = .POST request.body = .stream("1234".utf8.async, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } } @@ -244,9 +342,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let streamWriter = AsyncSequenceWriter() request.body = .stream(streamWriter, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) let fragments = [ @@ -257,16 +357,20 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var bodyIterator = response.body.makeAsyncIterator() for expectedFragment in fragments { streamWriter.write(expectedFragment) - guard let actualFragment = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let actualFragment = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(expectedFragment, actualFragment) } streamWriter.end() - guard let lastResult = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let lastResult = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(lastResult, nil) } } @@ -283,9 +387,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let streamWriter = AsyncSequenceWriter() request.body = .stream(streamWriter, length: .unknown) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response.headers["content-length"], []) let fragments = [ @@ -297,16 +403,20 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var bodyIterator = response.body.makeAsyncIterator() for expectedFragment in fragments { streamWriter.write(expectedFragment) - guard let actualFragment = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let actualFragment = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(expectedFragment, actualFragment) } streamWriter.end() - guard let lastResult = await XCTAssertNoThrowWithResult( - try await bodyIterator.next() - ) else { return } + guard + let lastResult = await XCTAssertNoThrowWithResult( + try await bodyIterator.next() + ) + else { return } XCTAssertEqual(lastResult, nil) } } @@ -380,7 +490,10 @@ final class AsyncAwaitEndToEndTests: XCTestCase { // a race between deadline and connect timer can result in either error. // If closing happens really fast we might shutdown the pipeline before we fail the request. // If the pipeline is closed we may receive a `.remoteConnectionClosed`. - XCTAssertTrue([.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), "unexpected error \(error)") + XCTAssertTrue( + [.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), + "unexpected error \(error)" + ) } } } @@ -404,12 +517,17 @@ final class AsyncAwaitEndToEndTests: XCTestCase { // a race between deadline and connect timer can result in either error. // If closing happens really fast we might shutdown the pipeline before we fail the request. // If the pipeline is closed we may receive a `.remoteConnectionClosed`. - XCTAssertTrue([.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), "unexpected error \(error)") + XCTAssertTrue( + [.deadlineExceeded, .connectTimeout, .remoteConnectionClosed].contains(error), + "unexpected error \(error)" + ) } } } func testConnectTimeout() { + let serverGroup = self.serverGroup! + let clientGroup = self.clientGroup! XCTAsyncTest(timeout: 60) { #if os(Linux) // 198.51.100.254 is reserved for documentation only and therefore should not accept any TCP connection @@ -426,7 +544,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - let serverChannel = try await ServerBootstrap(group: self.serverGroup) + let serverChannel = try await ServerBootstrap(group: serverGroup) .serverChannelOption(ChannelOptions.backlog, value: 1) .serverChannelOption(ChannelOptions.autoRead, value: false) .bind(host: "127.0.0.1", port: 0) @@ -435,7 +553,7 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertNoThrow(try serverChannel.close().wait()) } let port = serverChannel.localAddress!.port! - let firstClientChannel = try await ClientBootstrap(group: self.serverGroup) + let firstClientChannel = try await ClientBootstrap(group: serverGroup) .connect(host: "127.0.0.1", port: port) .get() defer { @@ -444,8 +562,10 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let url = "http://localhost:\(port)/get" #endif - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150)))) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150))) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) @@ -471,16 +591,19 @@ final class AsyncAwaitEndToEndTests: XCTestCase { /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) let configuration = TLSConfiguration.makeServerConfiguration( certificateChain: try NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, - privateKey: .file(keyPath) + privateKey: .privateKey(key) ) let sslContext = try NIOSSLContext(configuration: configuration) let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } let server = ServerBootstrap(group: serverGroup) .childChannelInitializer { channel in - channel.pipeline.addHandler(NIOSSLServerHandler(context: sslContext)) + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } } let serverChannel = try await server.bind(host: "localhost", port: 0).get() defer { XCTAssertNoThrow(try serverChannel.close().wait()) } @@ -492,7 +615,8 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let localClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: config) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } let request = HTTPClientRequest(url: "https://localhost:\(port)") - await XCTAssertThrowsError(try await localClient.execute(request, deadline: .now() + .seconds(2))) { error in + await XCTAssertThrowsError(try await localClient.execute(request, deadline: .now() + .seconds(2))) { + error in #if canImport(Network) guard let nwTLSError = error as? HTTPClient.NWTLSError else { XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") @@ -501,7 +625,8 @@ final class AsyncAwaitEndToEndTests: XCTestCase { XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") #else guard let sslError = error as? NIOSSLError, - case .handshakeFailed(.sslError) = sslError else { + case .handshakeFailed(.sslError) = sslError + else { XCTFail("unexpected error \(error)") return } @@ -512,39 +637,40 @@ final class AsyncAwaitEndToEndTests: XCTestCase { func testDnsOverride() { XCTAsyncTest(timeout: 5) { - /// key + cert was created with the following code (depends on swift-certificates) - /// ``` - /// import X509 - /// import CryptoKit - /// import Foundation - /// - /// let privateKey = P384.Signing.PrivateKey() - /// let name = try DistinguishedName { - /// OrganizationName("Self Signed") - /// CommonName("localhost") - /// } - /// let certificate = try Certificate( - /// version: .v3, - /// serialNumber: .init(), - /// publicKey: .init(privateKey.publicKey), - /// notValidBefore: Date(), - /// notValidAfter: Date().advanced(by: 365 * 24 * 3600), - /// issuer: name, - /// subject: name, - /// signatureAlgorithm: .ecdsaWithSHA384, - /// extensions: try .init { - /// SubjectAlternativeNames([.dnsName("example.com")]) - /// try ExtendedKeyUsage([.serverAuth]) - /// }, - /// issuerPrivateKey: .init(privateKey) - /// ) - /// ``` + // key + cert was created with the following code (depends on swift-certificates) + // ``` + // import X509 + // import CryptoKit + // import Foundation + // + // let privateKey = P384.Signing.PrivateKey() + // let name = try DistinguishedName { + // OrganizationName("Self Signed") + // CommonName("localhost") + // } + // let certificate = try Certificate( + // version: .v3, + // serialNumber: .init(), + // publicKey: .init(privateKey.publicKey), + // notValidBefore: Date(), + // notValidAfter: Date().advanced(by: 365 * 24 * 3600), + // issuer: name, + // subject: name, + // signatureAlgorithm: .ecdsaWithSHA384, + // extensions: try .init { + // SubjectAlternativeNames([.dnsName("example.com")]) + // try ExtendedKeyUsage([.serverAuth]) + // }, + // issuerPrivateKey: .init(privateKey) + // ) + // ``` let certPath = Bundle.module.path(forResource: "example.com.cert", ofType: "pem")! let keyPath = Bundle.module.path(forResource: "example.com.private-key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) let localhostCert = try NIOSSLCertificate.fromPEMFile(certPath) let configuration = TLSConfiguration.makeServerConfiguration( certificateChain: localhostCert.map { .certificate($0) }, - privateKey: .file(keyPath) + privateKey: .privateKey(key) ) let bin = HTTPBin(.http2(tlsConfiguration: configuration)) defer { XCTAssertNoThrow(try bin.shutdown()) } @@ -561,7 +687,9 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let localClient = HTTPClient(eventLoopGroupProvider: .singleton, configuration: config) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } let request = HTTPClientRequest(url: "https://example.com:\(bin.port)/echohostheader") - let response = await XCTAssertNoThrowWithResult(try await localClient.execute(request, deadline: .now() + .seconds(2))) + let response = await XCTAssertNoThrowWithResult( + try await localClient.execute(request, deadline: .now() + .seconds(2)) + ) XCTAssertEqual(response?.status, .ok) XCTAssertEqual(response?.version, .http2) var body = try await response?.body.collect(upTo: 1024) @@ -576,14 +704,34 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let client = makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) - let request = HTTPClientRequest(url: "") // invalid URL + let request = HTTPClientRequest(url: "") // invalid URL - await XCTAssertThrowsError(try await client.execute(request, deadline: .now() + .seconds(2), logger: logger)) { + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(2), logger: logger) + ) { XCTAssertEqual($0 as? HTTPClientError, .invalidURL) } } } + func testInsanelyHighConcurrentHTTP1ConnectionLimitDoesNotCrash() async throws { + let bin = HTTPBin(.http1_1(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + var httpClientConfig = HTTPClient.Configuration() + httpClientConfig.connectionPool = .init( + idleTimeout: .hours(1), + concurrentHTTP1ConnectionsPerHostSoftLimit: Int.max + ) + httpClientConfig.timeout = .init(connect: .seconds(10), read: .seconds(100), write: .seconds(100)) + + let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: httpClientConfig) + defer { XCTAssertNoThrow(try httpClient.syncShutdown()) } + + let request = HTTPClientRequest(url: "http://localhost:\(bin.port)") + _ = try await httpClient.execute(request, deadline: .now() + .seconds(2)) + } + func testRedirectChangesHostHeader() { XCTAsyncTest { let bin = HTTPBin(.http2(compress: false)) @@ -592,18 +740,28 @@ final class AsyncAwaitEndToEndTests: XCTestCase { defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://127.0.0.1:\(bin.port)/redirect/target") - request.headers.replaceOrAdd(name: "X-Target-Redirect-URL", value: "https://localhost:\(bin.port)/echohostheader") + let redirectURL = "https://localhost:\(bin.port)/echohostheader" + request.headers.replaceOrAdd( + name: "X-Target-Redirect-URL", + value: redirectURL + ) - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { + return + } + guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return } - guard let body = await XCTAssertNoThrowWithResult(try await response.body.collect(upTo: 1024)) else { return } var maybeRequestInfo: RequestInfo? XCTAssertNoThrow(maybeRequestInfo = try JSONDecoder().decode(RequestInfo.self, from: body)) guard let requestInfo = maybeRequestInfo else { return } + XCTAssertEqual(response.url?.absoluteString, redirectURL) + XCTAssertEqual(response.history.map(\.request.url), [request.url, redirectURL]) XCTAssertEqual(response.status, .ok) XCTAssertEqual(response.version, .http2) XCTAssertEqual(requestInfo.data, "localhost:\(bin.port)") @@ -646,28 +804,39 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/") request.method = .POST - request.body = .stream([ - ByteBuffer(string: "1"), - ByteBuffer(string: "2"), - ByteBuffer(string: "34"), - ].async, length: .unknown) + request.body = .stream( + [ + ByteBuffer(string: "1"), + ByteBuffer(string: "2"), + ByteBuffer(string: "34"), + ].async, + length: .unknown + ) - guard let response1 = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response1 = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response1.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response1.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response1.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) - guard let response2 = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response2 = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } XCTAssertEqual(response2.headers["content-length"], []) - guard let body = await XCTAssertNoThrowWithResult( - try await response2.body.collect(upTo: 1024) - ) else { return } + guard + let body = await XCTAssertNoThrowWithResult( + try await response2.body.collect(upTo: 1024) + ) + else { return } XCTAssertEqual(body, ByteBuffer(string: "1234")) } } @@ -706,9 +875,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.headers.add(name: weirdAllowedFieldName, value: "present") // This should work fine. - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } @@ -725,7 +896,9 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") request.headers.add(name: forbiddenFieldName, value: "present") - await XCTAssertThrowsError(try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)) { error in + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) { error in XCTAssertEqual(error as? HTTPClientError, .invalidHeaderFieldNames([forbiddenFieldName])) } } @@ -749,15 +922,18 @@ final class AsyncAwaitEndToEndTests: XCTestCase { let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) // We reject all ASCII control characters except HTAB and tolerate everything else. - let weirdAllowedFieldValue = "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" + let weirdAllowedFieldValue = + "!\" \t#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~" var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") request.headers.add(name: "Weird-Value", value: weirdAllowedFieldValue) // This should work fine. - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } @@ -774,7 +950,9 @@ final class AsyncAwaitEndToEndTests: XCTestCase { var request = HTTPClientRequest(url: "https://localhost:\(bin.port)/get") request.headers.add(name: "Weird-Value", value: forbiddenFieldValue) - await XCTAssertThrowsError(try await client.execute(request, deadline: .now() + .seconds(10), logger: logger)) { error in + await XCTAssertThrowsError( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) { error in XCTAssertEqual(error as? HTTPClientError, .invalidHeaderFieldValues([forbiddenFieldValue])) } } @@ -787,9 +965,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { request.headers.add(name: "Weird-Value", value: evenWeirderAllowedValue) // This should work fine. - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } @@ -806,9 +986,11 @@ final class AsyncAwaitEndToEndTests: XCTestCase { defer { XCTAssertNoThrow(try client.syncShutdown()) } let request = try HTTPClient.Request(url: "https://localhost:\(bin.port)/get") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request: request).get() - ) else { + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request: request).get() + ) + else { return } @@ -825,13 +1007,15 @@ final class AsyncAwaitEndToEndTests: XCTestCase { defer { XCTAssertNoThrow(try client.syncShutdown()) } let logger = Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) let request = HTTPClientRequest(url: "https://localhost:\(bin.port)/content-length-without-body") - guard let response = await XCTAssertNoThrowWithResult( - try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) - ) else { return } + guard + let response = await XCTAssertNoThrowWithResult( + try await client.execute(request, deadline: .now() + .seconds(10), logger: logger) + ) + else { return } await XCTAssertThrowsError( try await response.body.collect(upTo: 3) ) { - XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError()) + XCTAssertEqualTypeAndValue($0, NIOTooManyBytesError(maxBytes: 3)) } } } diff --git a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift index 147b24dca..4a5c8d486 100644 --- a/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift +++ b/Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift @@ -17,7 +17,7 @@ import NIOCore /// ``AsyncSequenceWriter`` is `Sendable` because its state is protected by a Lock @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { +final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { typealias AsyncIterator = Iterator struct Iterator: AsyncIteratorProtocol { @@ -33,7 +33,7 @@ final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { } func makeAsyncIterator() -> Iterator { - return Iterator(self) + Iterator(self) } private enum State { @@ -117,7 +117,9 @@ final class AsyncSequenceWriter: AsyncSequence, @unchecked Sendable { case .waiting: let state = self._state self.lock.unlock() - preconditionFailure("Expected that there is always only one concurrent call to next. Invalid state: \(state)") + preconditionFailure( + "Expected that there is always only one concurrent call to next. Invalid state: \(state)" + ) } } diff --git a/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift b/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift index 79c304fc2..962791334 100644 --- a/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift +++ b/Tests/AsyncHTTPClientTests/ConnectionPoolSizeConfigValueIsRespectedTests.swift @@ -14,9 +14,6 @@ import AsyncHTTPClient import Atomics -#if canImport(Network) -import Network -#endif import Logging import NIOConcurrencyHelpers import NIOCore @@ -29,6 +26,10 @@ import NIOTestUtils import NIOTransportServices import XCTest +#if canImport(Network) +import Network +#endif + final class ConnectionPoolSizeConfigValueIsRespectedTests: XCTestCaseHTTPClientTestsBaseClass { func testConnectionPoolSizeConfigValueIsRespected() { let numberOfRequestsPerThread = 1000 diff --git a/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift index 5e7a1a9bc..5cc35bce8 100644 --- a/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift +++ b/Tests/AsyncHTTPClientTests/EmbeddedChannel+HTTPConvenience.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP1 import NIOHTTP2 +@testable import AsyncHTTPClient + extension EmbeddedChannel { public func receiveHeadAndVerify(_ verify: (HTTPRequestHead) throws -> Void = { _ in }) throws { let part = try self.readOutbound(as: HTTPClientRequestPart.self) @@ -58,7 +59,7 @@ extension EmbeddedChannel { } struct HTTP1TestTools { - let connection: HTTP1Connection + let connection: HTTP1Connection.SendableView let connectionDelegate: MockConnectionDelegate let readEventHandler: ReadEventHitHandler let logger: Logger @@ -86,8 +87,8 @@ extension EmbeddedChannel { let decoder = try self.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) let encoder = try self.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self) - let removeDecoderFuture = self.pipeline.removeHandler(decoder) - let removeEncoderFuture = self.pipeline.removeHandler(encoder) + let removeDecoderFuture = self.pipeline.syncOperations.removeHandler(decoder) + let removeEncoderFuture = self.pipeline.syncOperations.removeHandler(encoder) self.embeddedEventLoop.run() @@ -95,7 +96,7 @@ extension EmbeddedChannel { try removeEncoderFuture.wait() return .init( - connection: connection, + connection: connection.sendableView, connectionDelegate: connectionDelegate, readEventHandler: readEventHandler, logger: logger @@ -111,6 +112,6 @@ public struct HTTP1EmbeddedChannelError: Error, Hashable, CustomStringConvertibl } public var description: String { - return self.reason + self.reason } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift index f6a2840d9..0d871b7dc 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ClientChannelHandlerTests.swift @@ -12,13 +12,15 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP1ClientChannelHandlerTests: XCTestCase { func testResponseBackpressure() { let embedded = EmbeddedChannel() @@ -32,27 +34,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -113,22 +123,30 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in - testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 100) { writer in + testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } // the handler only writes once the channel is writable @@ -143,12 +161,14 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { testWriter.writabilityChanged(true) embedded.pipeline.fireChannelWritabilityChanged() - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .POST) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - XCTAssertEqual($0.headers.first(name: "content-length"), "100") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + } + ) // the next body write will be executed once we tick the el. before we make the channel // unwritable @@ -162,9 +182,11 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { embedded.embeddedEventLoop.run() - XCTAssertNoThrow(try embedded.receiveBodyAndVerify { - XCTAssertEqual($0.readableBytes, 2) - }) + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) XCTAssertEqual(testWriter.written, index + 1) @@ -201,24 +223,28 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) XCTAssertTrue(embedded.isActive) @@ -247,27 +273,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -299,27 +333,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -345,25 +387,33 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in - // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. - embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) - return testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = true @@ -376,42 +426,108 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { } } - func testIdleWriteTimeoutWritabilityChanged() { + func testIdleWriteTimeoutRaceToEnd() { let embedded = EmbeddedChannel() - let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) var maybeTestUtils: HTTP1TestTools? XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in - embedded.isWritable = false - embedded.pipeline.fireChannelWritabilityChanged() - // This should not trigger any errors or timeouts, because the timer isn't running - // as the channel is not writable. - embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream { _ in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + let scheduled = embedded.embeddedEventLoop.flatScheduleTask(in: .milliseconds(2)) { + embedded.embeddedEventLoop.makeSucceededVoidFuture() + } + return scheduled.futureResult + } + ) + ) - // Now that the channel will become writable, this should trigger a timeout. - embedded.isWritable = true - embedded.pipeline.fireChannelWritabilityChanged() - embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + let delegate = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(5)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - return testWriter.start(writer: writer) - })) + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + testUtils.connection.executeRequest(requestBag) + let expectedHeaders: HTTPHeaders = ["host": "localhost", "Transfer-Encoding": "chunked"] + XCTAssertEqual( + try embedded.readOutbound(as: HTTPClientRequestPart.self), + .head(HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: expectedHeaders)) + ) + + // change the writability to false. + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.run() + + // let the writer, write an end (while writability is false) + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) + } + + func testIdleWriteTimeoutWritabilityChanged() { + let embedded = EmbeddedChannel() + let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = true @@ -432,22 +548,30 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 2) { writer in - return testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 2) { writer in + testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = true @@ -478,27 +602,35 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "50")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "50")]) + ) XCTAssertEqual(testUtils.readEventHandler.readHitCounter, 0) embedded.read() @@ -549,7 +681,12 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } - XCTAssertNoThrow(try embedded.pipeline.syncOperations.addHandler(FailWriteHandler(), position: .after(testUtils.readEventHandler))) + XCTAssertNoThrow( + try embedded.pipeline.syncOperations.addHandler( + FailWriteHandler(), + position: .after(testUtils.readEventHandler) + ) + ) let logger = Logger(label: "test") @@ -559,16 +696,20 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } embedded.isWritable = false XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) @@ -595,22 +736,30 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in - testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } XCTAssertNoThrow(try embedded.pipeline.addHandler(FailEndHandler(), position: .first).wait()) @@ -618,12 +767,14 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { // Execute the request and we'll receive the head. testWriter.writabilityChanged(true) testUtils.connection.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .POST) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - XCTAssertEqual($0.headers.first(name: "content-length"), "10") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "10") + } + ) // We're going to immediately send the response head and end. let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) XCTAssertNoThrow(try embedded.writeInbound(HTTPClientResponsePart.head(responseHead))) @@ -639,9 +790,11 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { embedded.embeddedEventLoop.run() XCTAssertEqual(testWriter.written, 5) for _ in 0..<5 { - XCTAssertNoThrow(try embedded.receiveBodyAndVerify { - XCTAssertEqual($0.readableBytes, 2) - }) + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) } embedded.embeddedEventLoop.run() @@ -672,49 +825,117 @@ class HTTP1ClientChannelHandlerTests: XCTestCase { backgroundLogger: Logger(label: "no-op", factory: SwiftLogNoOpLogHandler.init), connectionIdLoggerMetadata: "test connection" ) - let channel = EmbeddedChannel(handlers: [ - ChangeWritabilityOnFlush(), - handler, - ], loop: eventLoop) + let channel = EmbeddedChannel( + handlers: [ + ChangeWritabilityOnFlush(), + handler, + ], + loop: eventLoop + ) try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() - let request = MockHTTPExecutableRequest() // non empty body is important to trigger this bug as we otherwise finish the request in a single flush - request.requestFramingMetadata.body = .fixedSize(1) - request.raiseErrorIfUnimplementedMethodIsCalled = false + let request = MockHTTPExecutableRequest( + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)), + raiseErrorIfUnimplementedMethodIsCalled: false + ) channel.writeAndFlush(request, promise: nil) XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } + + func testIdleWriteTimeoutOutsideOfRunningState() { + let embedded = EmbeddedChannel() + var maybeTestUtils: HTTP1TestTools? + XCTAssertNoThrow(maybeTestUtils = try embedded.setupHTTP1Connection()) + print("pipeline", embedded.pipeline) + guard let testUtils = maybeTestUtils else { return XCTFail("Expected connection setup works") } + + var maybeRequest: HTTPClient.Request? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/")) + guard var request = maybeRequest else { return XCTFail("Expected to be able to create a request") } + + // start a request stream we'll never write to + let streamPromise = embedded.eventLoop.makePromise(of: Void.self) + let streamCallback = { @Sendable (streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture in + streamPromise.futureResult + } + request.body = .init(contentLength: nil, stream: streamCallback) + + let accumulator = ResponseAccumulator(request: request) + var maybeRequestBag: RequestBag? + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: testUtils.logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests( + idleReadTimeout: .milliseconds(10), + idleWriteTimeout: .milliseconds(2) + ), + delegate: accumulator + ) + ) + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + + testUtils.connection.executeRequest(requestBag) + + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) + + // close the pipeline to simulate a server-side close + // note this happens before we write so the idle write timeout is still running + try! embedded.pipeline.close().wait() + + // advance time to trigger the idle write timeout + // and ensure that the state machine can tolerate this + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(250)) + } } -class TestBackpressureWriter { +final class TestBackpressureWriter: Sendable { let eventLoop: EventLoop let parts: Int var finishFuture: EventLoopFuture { self.finishPromise.futureResult } private let finishPromise: EventLoopPromise - private(set) var written: Int = 0 - private var channelIsWritable: Bool = false + private struct State { + var written = 0 + var channelIsWritable = false + } + + var written: Int { + self.state.value.written + } + + private let state: NIOLoopBoundBox init(eventLoop: EventLoop, parts: Int) { self.eventLoop = eventLoop self.parts = parts - + self.state = .makeBoxSendingValue(State(), eventLoop: eventLoop) self.finishPromise = eventLoop.makePromise(of: Void.self) } func start(writer: HTTPClient.Body.StreamWriter, expectedErrors: [HTTPClientError] = []) -> EventLoopFuture { + @Sendable func recursive() { XCTAssert(self.eventLoop.inEventLoop) - XCTAssert(self.channelIsWritable) - if self.written == self.parts { + XCTAssert(self.state.value.channelIsWritable) + if self.state.value.written == self.parts { self.finishPromise.succeed(()) } else { self.eventLoop.execute { let future = writer.write(.byteBuffer(.init(bytes: [0, 1]))) - self.written += 1 + self.state.value.written += 1 future.whenComplete { result in switch result { case .success: @@ -741,14 +962,14 @@ class TestBackpressureWriter { } func writabilityChanged(_ newValue: Bool) { - self.channelIsWritable = newValue + self.state.value.channelIsWritable = newValue } } -class ResponseBackpressureDelegate: HTTPClientResponseDelegate { +final class ResponseBackpressureDelegate: HTTPClientResponseDelegate { typealias Response = Void - enum State { + enum State: Sendable { case consuming(EventLoopPromise) case waitingForRemote(CircularBuffer>) case buffering((ByteBuffer?, EventLoopPromise)?) @@ -756,40 +977,42 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } let eventLoop: EventLoop - private var state: State = .buffering(nil) + private let state: NIOLoopBoundBox init(eventLoop: EventLoop) { self.eventLoop = eventLoop - - self.state = .consuming(self.eventLoop.makePromise(of: Void.self)) + self.state = .makeBoxSendingValue(.consuming(eventLoop.makePromise(of: Void.self)), eventLoop: eventLoop) } func next() -> EventLoopFuture { - switch self.state { + switch self.state.value { case .consuming(let backpressurePromise): var promiseBuffer = CircularBuffer>() let newPromise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(newPromise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) backpressurePromise.succeed(()) return newPromise.futureResult case .waitingForRemote(var promiseBuffer): - assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + assert( + !promiseBuffer.isEmpty, + "assert expected to be waiting if we have at least one promise in the buffer" + ) let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(promise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) return promise.futureResult case .buffering(.none): var promiseBuffer = CircularBuffer>() let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) promiseBuffer.append(promise) - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) return promise.futureResult case .buffering(.some((let buffer, let promise))): - self.state = .buffering(nil) + self.state.value = .buffering(nil) promise.succeed(()) return self.eventLoop.makeSucceededFuture(buffer) @@ -799,7 +1022,7 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - switch self.state { + switch self.state.value { case .consuming(let backpressurePromise): return backpressurePromise.futureResult @@ -812,28 +1035,33 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - switch self.state { + switch self.state.value { case .waitingForRemote(var promiseBuffer): - assert(!promiseBuffer.isEmpty, "assert expected to be waiting if we have at least one promise in the buffer") + assert( + !promiseBuffer.isEmpty, + "assert expected to be waiting if we have at least one promise in the buffer" + ) let promise = promiseBuffer.removeFirst() if promiseBuffer.isEmpty { let newBackpressurePromise = self.eventLoop.makePromise(of: Void.self) - self.state = .consuming(newBackpressurePromise) + self.state.value = .consuming(newBackpressurePromise) promise.succeed(buffer) return newBackpressurePromise.futureResult } else { - self.state = .waitingForRemote(promiseBuffer) + self.state.value = .waitingForRemote(promiseBuffer) promise.succeed(buffer) return self.eventLoop.makeSucceededVoidFuture() } case .buffering(.none): let promise = self.eventLoop.makePromise(of: Void.self) - self.state = .buffering((buffer, promise)) + self.state.value = .buffering((buffer, promise)) return promise.futureResult case .buffering(.some): - preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + preconditionFailure( + "Did receive response part should not be called, before the previous promise was succeeded." + ) case .done, .consuming: preconditionFailure("Invalid state: \(self.state)") @@ -841,21 +1069,23 @@ class ResponseBackpressureDelegate: HTTPClientResponseDelegate { } func didFinishRequest(task: HTTPClient.Task) throws { - switch self.state { + switch self.state.value { case .waitingForRemote(let promiseBuffer): - promiseBuffer.forEach { - $0.succeed(.none) + for promise in promiseBuffer { + promise.succeed(.none) } - self.state = .done + self.state.value = .done case .buffering(.none): - self.state = .done + self.state.value = .done case .done, .consuming: preconditionFailure("Invalid state: \(self.state)") case .buffering(.some): - preconditionFailure("Did receive response part should not be called, before the previous promise was succeeded.") + preconditionFailure( + "Did receive response part should not be called, before the previous promise was succeeded." + ) } } } @@ -873,7 +1103,7 @@ class ReadEventHitHandler: ChannelOutboundHandler { } } -final class FailEndHandler: ChannelOutboundHandler { +final class FailEndHandler: ChannelOutboundHandler, Sendable { typealias OutboundIn = HTTPClientRequestPart typealias OutboundOut = HTTPClientRequestPart diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift index e256aa49e..1c6e9659f 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionStateMachineTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOHTTP1 import NIOHTTPCompression import XCTest +@testable import AsyncHTTPClient + class HTTP1ConnectionStateMachineTests: XCTestCase { func testPOSTRequestWithWriteAndReadBackpressure() { var state = HTTP1ConnectionStateMachine() @@ -27,7 +28,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) @@ -51,7 +55,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) @@ -66,10 +73,16 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "12"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -88,6 +101,26 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) } + func testWriteTimeoutAfterErrorDoesntCrash() { + var state = HTTP1ConnectionStateMachine() + XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) + + let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") + let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) + let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) + XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + + struct MyError: Error, Equatable {} + XCTAssertEqual(state.errorHappened(MyError()), .failRequest(MyError(), .close(nil))) + + // Primarily we care that we don't crash here + XCTAssertEqual(state.idleWriteTimeoutTriggered(), .wait) + } + func testAConnectionCloseHeaderInTheRequestLeadsToConnectionCloseAfterRequest() { var state = HTTP1ConnectionStateMachine() XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) @@ -95,10 +128,16 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: true, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -112,10 +151,16 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -129,10 +174,20 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) - - let responseHead = HTTPResponseHead(version: .http1_0, status: .ok, headers: ["content-length": "4", "connection": "keep-alive"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_0, + status: .ok, + headers: ["content-length": "4", "connection": "keep-alive"] + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.informConnectionIsIdle, .init([responseBody]))) @@ -147,10 +202,16 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["connection": "close"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init([responseBody]))) @@ -191,13 +252,19 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: false)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) - XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .close(nil))) + XCTAssertEqual( + state.requestCancelled(closeConnection: false), + .failRequest(HTTPClientError.cancelled, .close(nil)) + ) } func testNewRequestAfterErrorHappened() { @@ -218,9 +285,17 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { XCTAssertEqual(state.channelActive(isWritable: true), .fireChannelActive) XCTAssertEqual(state.requestCancelled(closeConnection: false), .wait, "Should be ignored.") XCTAssertEqual(state.requestCancelled(closeConnection: true), .close, "Should lead to connection closure.") - XCTAssertEqual(state.requestCancelled(closeConnection: true), .wait, "Should be ignored. Connection is already closing") + XCTAssertEqual( + state.requestCancelled(closeConnection: true), + .wait, + "Should be ignored. Connection is already closing" + ) XCTAssertEqual(state.channelInactive(), .fireChannelInactive) - XCTAssertEqual(state.requestCancelled(closeConnection: true), .wait, "Should be ignored. Connection is already closed") + XCTAssertEqual( + state.requestCancelled(closeConnection: true), + .wait, + "Should be ignored. Connection is already closed" + ) } func testReadsAreForwardedIfConnectionIsClosing() { @@ -248,7 +323,10 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: ["content-length": "4"]) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) XCTAssertEqual(state.runNewRequest(head: requestHead, metadata: metadata), .wait) - XCTAssertEqual(state.requestCancelled(closeConnection: false), .failRequest(HTTPClientError.cancelled, .informConnectionIsIdle)) + XCTAssertEqual( + state.requestCancelled(closeConnection: false), + .failRequest(HTTPClientError.cancelled, .informConnectionIsIdle) + ) } func testConnectionIsClosedIfErrorHappensWhileInRequest() { @@ -258,9 +336,15 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Hello world!\n"))), .wait) XCTAssertEqual(state.channelRead(.body(ByteBuffer(string: "Foo Bar!\n"))), .wait) let decompressionError = NIOHTTPDecompression.DecompressionError.limit @@ -274,9 +358,15 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .switchingProtocols) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, [])) } @@ -287,8 +377,14 @@ class HTTP1ConnectionStateMachineTests: XCTestCase { let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) let newRequestAction = state.runNewRequest(head: requestHead, metadata: metadata) XCTAssertEqual(newRequestAction, .sendRequestHead(requestHead, sendEnd: true)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .init(statusCode: 103, reasonPhrase: "Early Hints")) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: false, startIdleTimer: true) + ) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .init(statusCode: 103, reasonPhrase: "Early Hints") + ) XCTAssertEqual(state.channelRead(.head(responseHead)), .wait) XCTAssertEqual(state.channelInactive(), .failRequest(HTTPClientError.remoteConnectionClosed, .none)) } @@ -339,13 +435,19 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream)): + case ( + .forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), + .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream) + ): return lhsHead == rhsHead && lhsPauseRequestBodyStream == rhsPauseRequestBodyStream case (.forwardResponseBodyParts(let lhsData), .forwardResponseBodyParts(let rhsData)): return lhsData == rhsData - case (.succeedRequest(let lhsFinalAction, let lhsFinalBuffer), .succeedRequest(let rhsFinalAction, let rhsFinalBuffer)): + case ( + .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), + .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + ): return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): @@ -367,7 +469,10 @@ extension HTTP1ConnectionStateMachine.Action: Equatable { } extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equatable { - public static func == (lhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction, rhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction) -> Bool { + public static func == ( + lhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction, + rhs: HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction + ) -> Bool { switch (lhs, rhs) { case (.close, .close): return true @@ -382,7 +487,10 @@ extension HTTP1ConnectionStateMachine.Action.FinalSuccessfulStreamAction: Equata } extension HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction: Equatable { - public static func == (lhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction, rhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction) -> Bool { + public static func == ( + lhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction, + rhs: HTTP1ConnectionStateMachine.Action.FinalFailedStreamAction + ) -> Bool { switch (lhs, rhs) { case (.close(let lhsPromise), .close(let rhsPromise)): return lhsPromise?.futureResult == rhsPromise?.futureResult diff --git a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift index 3ff73de06..53001b64b 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ConnectionTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore @@ -23,6 +22,8 @@ import NIOPosix import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTP1ConnectionTests: XCTestCase { func testCreateNewConnectionWithDecompression() { let embedded = EmbeddedChannel() @@ -31,19 +32,23 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) var connection: HTTP1Connection? - XCTAssertNoThrow(connection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - decompression: .enabled(limit: .ratio(4)), - logger: logger - )) + XCTAssertNoThrow( + connection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) - XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) + XCTAssertNotNil( + try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) - XCTAssertNoThrow(try connection?.close().wait()) + XCTAssertNoThrow(try connection?.sendableView.close().wait()) embedded.embeddedEventLoop.run() XCTAssert(!embedded.isActive) } @@ -54,17 +59,22 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 3000)).wait()) - XCTAssertNoThrow(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - decompression: .disabled, - logger: logger - )) + XCTAssertNoThrow( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .disabled, + logger: logger + ) + ) XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: HTTPRequestEncoder.self)) - XCTAssertNotNil(try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self)) - XCTAssertThrowsError(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) { error in + XCTAssertNotNil( + try embedded.pipeline.syncOperations.handler(type: ByteToMessageHandler.self) + ) + XCTAssertThrowsError(try embedded.pipeline.syncOperations.handler(type: NIOHTTPResponseDecompressor.self)) { + error in XCTAssertEqual(error as? ChannelPipelineError, .notFound) } } @@ -78,13 +88,15 @@ class HTTP1ConnectionTests: XCTestCase { embedded.embeddedEventLoop.run() let logger = Logger(label: "test.http1.connection") - XCTAssertThrowsError(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: MockHTTP1ConnectionDelegate(), - decompression: .disabled, - logger: logger - )) + XCTAssertThrowsError( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: MockHTTP1ConnectionDelegate(), + decompression: .disabled, + logger: logger + ) + ) } func testGETRequest() { @@ -96,8 +108,7 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try server.stop()) } let logger = Logger(label: "test") - let delegate = MockHTTP1ConnectionDelegate() - delegate.closePromise = clientEL.makePromise(of: Void.self) + let delegate = MockHTTP1ConnectionDelegate(closePromise: clientEL.makePromise()) let connection = try! ClientBootstrap(group: clientEL) .connect(to: .init(ipAddress: "127.0.0.1", port: server.serverPort)) @@ -108,35 +119,37 @@ class HTTP1ConnectionTests: XCTestCase { delegate: delegate, decompression: .disabled, logger: logger - ) + ).sendableView } .wait() var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost/hello/swift", - method: .POST, - body: .stream(length: 4) { writer -> EventLoopFuture in - func recursive(count: UInt8, promise: EventLoopPromise) { - guard count < 4 else { - return promise.succeed(()) - } + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/hello/swift", + method: .POST, + body: .stream(contentLength: 4) { writer -> EventLoopFuture in + @Sendable func recursive(count: UInt8, promise: EventLoopPromise) { + guard count < 4 else { + return promise.succeed(()) + } - writer.write(.byteBuffer(ByteBuffer(bytes: [count]))).whenComplete { result in - switch result { - case .failure(let error): - XCTFail("Unexpected error: \(error)") - case .success: - recursive(count: count + 1, promise: promise) + writer.write(.byteBuffer(ByteBuffer(bytes: [count]))).whenComplete { result in + switch result { + case .failure(let error): + XCTFail("Unexpected error: \(error)") + case .success: + recursive(count: count + 1, promise: promise) + } } } - } - let promise = clientEL.makePromise(of: Void.self) - recursive(count: 0, promise: promise) - return promise.futureResult - } - )) + let promise = clientEL.makePromise(of: Void.self) + recursive(count: 0, promise: promise) + return promise.futureResult + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a connection and a request") @@ -145,33 +158,39 @@ class HTTP1ConnectionTests: XCTestCase { let task = HTTPClient.Task(eventLoop: clientEL, logger: logger) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: clientEL), - task: task, - redirectHandler: nil, - connectionDeadline: .now() + .seconds(60), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: request) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: clientEL), + task: task, + redirectHandler: nil, + connectionDeadline: .now() + .seconds(60), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: request) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } connection.executeRequest(requestBag) - XCTAssertNoThrow(try server.receiveHeadAndVerify { head in - XCTAssertEqual(head.method, .POST) - XCTAssertEqual(head.uri, "/hello/swift") - XCTAssertEqual(head.headers["content-length"].first, "4") - }) + XCTAssertNoThrow( + try server.receiveHeadAndVerify { head in + XCTAssertEqual(head.method, .POST) + XCTAssertEqual(head.uri, "/hello/swift") + XCTAssertEqual(head.headers["content-length"].first, "4") + } + ) var received: UInt8 = 0 while received < 4 { - XCTAssertNoThrow(try server.receiveBodyAndVerify { body in - var body = body - while let read = body.readInteger(as: UInt8.self) { - XCTAssertEqual(received, read) - received += 1 + XCTAssertNoThrow( + try server.receiveBodyAndVerify { body in + var body = body + while let read = body.readInteger(as: UInt8.self) { + XCTAssertEqual(received, read) + received += 1 + } } - }) + ) } XCTAssertEqual(received, 4) XCTAssertNoThrow(try server.receiveEnd()) @@ -198,17 +217,23 @@ class HTTP1ConnectionTests: XCTestCase { var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - decompression: .disabled, - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { [maybeChannel] in + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var maybeRequest: HTTPClient.Request? @@ -217,15 +242,17 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) @@ -248,21 +275,29 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let closeOnRequest = (30...100).randomElement()! - let httpBin = HTTPBin(handlerFactory: { _ in SuddenlySendsCloseHeaderChannelHandler(closeOnRequest: closeOnRequest) }) + let httpBin = HTTPBin(handlerFactory: { _ in + SuddenlySendsCloseHeaderChannelHandler(closeOnRequest: closeOnRequest) + }) var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - decompression: .disabled, - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { [maybeChannel] in + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var counter = 0 @@ -275,16 +310,20 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } connection.executeRequest(requestBag) @@ -293,7 +332,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual(response?.status, .ok) if response?.headers.first(name: "connection") == "close" { - break // the loop + break // the loop } else { XCTAssertEqual(httpBin.activeConnections, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, counter) @@ -306,8 +345,11 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual(counter, closeOnRequest) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) - XCTAssertEqual(connectionDelegate.hitConnectionReleased, counter - 1, - "If a close header is received connection release is not triggered.") + XCTAssertEqual( + connectionDelegate.hitConnectionReleased, + counter - 1, + "If a close header is received connection release is not triggered." + ) // we need to wait a small amount of time to see the connection close on the server try! eventLoop.scheduleTask(in: .milliseconds(200)) {}.futureResult.wait() @@ -324,17 +366,23 @@ class HTTP1ConnectionTests: XCTestCase { var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoop).connect(host: "localhost", port: httpBin.port).wait() + ) let connectionDelegate = MockConnectionDelegate() let logger = Logger(label: "test") - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try eventLoop.submit { try HTTP1Connection.start( - channel: XCTUnwrap(maybeChannel), - connectionID: 0, - delegate: connectionDelegate, - decompression: .disabled, - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try eventLoop.submit { [maybeChannel] in + try HTTP1Connection.start( + channel: XCTUnwrap(maybeChannel), + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection here") } var maybeRequest: HTTPClient.Request? @@ -343,15 +391,17 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: eventLoopGroup.next()), - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: eventLoopGroup.next()), + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } connection.executeRequest(requestBag) @@ -373,13 +423,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - decompression: .enabled(limit: .ratio(4)), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -388,38 +440,40 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end let responseString = """ - HTTP/1.1 101 Switching Protocols\r\n\ - Upgrade: websocket\r\n\ - Sec-WebSocket-Accept: xAMUK7/Il9bLRFJrikq6mm8CNZI=\r\n\ - Connection: upgrade\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\nfoo bar baz - """ + HTTP/1.1 101 Switching Protocols\r\n\ + Upgrade: websocket\r\n\ + Sec-WebSocket-Accept: xAMUK7/Il9bLRFJrikq6mm8CNZI=\r\n\ + Connection: upgrade\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\nfoo bar baz + """ XCTAssertTrue(embedded.isActive) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -438,13 +492,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - decompression: .enabled(limit: .ratio(4)), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -453,28 +509,30 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end let responseString = """ - HTTP/1.1 103 Early Hints\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\n - """ + HTTP/1.1 103 Early Hints\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\n + """ XCTAssertTrue(embedded.isActive) XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) @@ -484,7 +542,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertTrue(embedded.isActive, "The connection remains active after the informational response head") XCTAssertNoThrow(try embedded.close().wait(), "the connection was closed") - embedded.embeddedEventLoop.run() // tick once to run futures. + embedded.embeddedEventLoop.run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -500,20 +558,22 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: SocketAddress(ipAddress: "127.0.0.1", port: 0)).wait()) let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - decompression: .enabled(limit: .ratio(4)), - logger: logger - )) + XCTAssertNoThrow( + try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) let responseString = """ - HTTP/1.1 200 OK\r\n\ - date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ - \r\n\ - \r\n - """ + HTTP/1.1 200 OK\r\n\ + date: Mon, 27 Sep 2021 17:53:14 GMT\r\n\ + \r\n\ + \r\n + """ XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) @@ -522,7 +582,7 @@ class HTTP1ConnectionTests: XCTestCase { XCTAssertEqual($0 as? NIOHTTPDecoderError, .unsolicitedResponse) } XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) } @@ -535,13 +595,15 @@ class HTTP1ConnectionTests: XCTestCase { var maybeConnection: HTTP1Connection? let connectionDelegate = MockConnectionDelegate() - XCTAssertNoThrow(maybeConnection = try HTTP1Connection.start( - channel: embedded, - connectionID: 0, - delegate: connectionDelegate, - decompression: .enabled(limit: .ratio(4)), - logger: logger - )) + XCTAssertNoThrow( + maybeConnection = try HTTP1Connection.start( + channel: embedded, + connectionID: 0, + delegate: connectionDelegate, + decompression: .enabled(limit: .ratio(4)), + logger: logger + ) + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point.") } var maybeRequest: HTTPClient.Request? @@ -550,32 +612,34 @@ class HTTP1ConnectionTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } - connection.executeRequest(requestBag) + connection.sendableView.executeRequest(requestBag) let responseString = """ - HTTP/1.0 200 OK\r\n\ - HTTP/1.0 200 OK\r\n\r\n - """ + HTTP/1.0 200 OK\r\n\ + HTTP/1.0 200 OK\r\n\r\n + """ - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head - XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // head + XCTAssertNoThrow(try embedded.readOutbound(as: ByteBuffer.self)) // end XCTAssertEqual(connectionDelegate.hitConnectionClosed, 0) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) XCTAssertNoThrow(try embedded.writeInbound(ByteBuffer(string: responseString))) XCTAssertFalse(embedded.isActive) - (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. + (embedded.eventLoop as! EmbeddedEventLoop).run() // tick once to run futures. XCTAssertEqual(connectionDelegate.hitConnectionClosed, 1) XCTAssertEqual(connectionDelegate.hitConnectionReleased, 0) } @@ -589,42 +653,42 @@ class HTTP1ConnectionTests: XCTestCase { // bytes a ready to be read as well. This will allow us to test if subsequent reads // are waiting for backpressure promise. func testDownloadStreamingBackpressure() { - class BackpressureTestDelegate: HTTPClientResponseDelegate { + final class BackpressureTestDelegate: HTTPClientResponseDelegate { typealias Response = Void - var _reads = 0 - var _channel: Channel? + private struct State: Sendable { + var reads = 0 + var channel: Channel? + } + + private let state = NIOLockedValueBox(State()) + + var reads: Int { + self.state.withLockedValue { $0.reads } + } - let lock: NIOLock let backpressurePromise: EventLoopPromise let messageReceived: EventLoopPromise init(eventLoop: EventLoop) { - self.lock = NIOLock() self.backpressurePromise = eventLoop.makePromise() self.messageReceived = eventLoop.makePromise() } - var reads: Int { - return self.lock.withLock { - self._reads - } - } - func willExecuteOnChannel(_ channel: Channel) { - self.lock.withLock { - self._channel = channel + self.state.withLockedValue { + $0.channel = channel } } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - return task.futureResult.eventLoop.makeSucceededVoidFuture() + task.futureResult.eventLoop.makeSucceededVoidFuture() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { // We count a number of reads received. - self.lock.withLock { - self._reads += 1 + self.state.withLockedValue { + $0.reads += 1 } // We need to notify the test when first byte of the message is arrived. self.messageReceived.succeed(()) @@ -656,8 +720,8 @@ class HTTP1ConnectionTests: XCTestCase { let buffer = context.channel.allocator.buffer(string: "1234") context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(buffer))), promise: nil) - self.endFuture.hop(to: context.eventLoop).whenSuccess { - context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) + self.endFuture.hop(to: context.eventLoop).assumeIsolated().whenSuccess { + context.writeAndFlush(Self.wrapOutboundOut(.end(nil)), promise: nil) } } } @@ -679,34 +743,42 @@ class HTTP1ConnectionTests: XCTestCase { defer { XCTAssertNoThrow(try httpBin.shutdown()) } var maybeChannel: Channel? - XCTAssertNoThrow(maybeChannel = try ClientBootstrap(group: eventLoopGroup) - .channelOption(ChannelOptions.maxMessagesPerRead, value: 1) - .channelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) - .connect(host: "localhost", port: httpBin.port) - .wait()) + XCTAssertNoThrow( + maybeChannel = try ClientBootstrap(group: eventLoopGroup) + .channelOption(ChannelOptions.maxMessagesPerRead, value: 1) + .channelOption(ChannelOptions.recvAllocator, value: FixedSizeRecvByteBufferAllocator(capacity: 1)) + .connect(host: "localhost", port: httpBin.port) + .wait() + ) guard let channel = maybeChannel else { return XCTFail("Expected to have a channel at this point") } let connectionDelegate = MockConnectionDelegate() - var maybeConnection: HTTP1Connection? - XCTAssertNoThrow(maybeConnection = try channel.eventLoop.submit { try HTTP1Connection.start( - channel: channel, - connectionID: 0, - delegate: connectionDelegate, - decompression: .disabled, - logger: logger - ) }.wait()) + var maybeConnection: HTTP1Connection.SendableView? + XCTAssertNoThrow( + maybeConnection = try channel.eventLoop.submit { + try HTTP1Connection.start( + channel: channel, + connectionID: 0, + delegate: connectionDelegate, + decompression: .disabled, + logger: logger + ).sendableView + }.wait() + ) guard let connection = maybeConnection else { return XCTFail("Expected to have a connection at this point") } var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: HTTPClient.Request(url: "http://localhost:\(httpBin.port)/custom"), - eventLoopPreference: .delegate(on: requestEventLoop), - task: .init(eventLoop: requestEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: backpressureDelegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: HTTPClient.Request(url: "http://localhost:\(httpBin.port)/custom"), + eventLoopPreference: .delegate(on: requestEventLoop), + task: .init(eventLoop: requestEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: backpressureDelegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } backpressureDelegate.willExecuteOnChannel(connection.channel) @@ -729,15 +801,20 @@ class HTTP1ConnectionTests: XCTestCase { } } -class MockHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { - var releasePromise: EventLoopPromise? - var closePromise: EventLoopPromise? +final class MockHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { + let releasePromise: EventLoopPromise? + let closePromise: EventLoopPromise? + + init(releasePromise: EventLoopPromise? = nil, closePromise: EventLoopPromise? = nil) { + self.releasePromise = releasePromise + self.closePromise = closePromise + } - func http1ConnectionReleased(_: HTTP1Connection) { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { self.releasePromise?.succeed(()) } - func http1ConnectionClosed(_: HTTP1Connection) { + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { self.closePromise?.succeed(()) } } @@ -764,7 +841,12 @@ class SuddenlySendsCloseHeaderChannelHandler: ChannelInboundHandler { break case .end: if self.closeOnRequest == self.counter { - context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: ["connection": "close"]))), promise: nil) + context.write( + self.wrapOutboundOut( + .head(.init(version: .http1_1, status: .ok, headers: ["connection": "close"])) + ), + promise: nil + ) context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() self.counter += 1 @@ -797,38 +879,40 @@ class AfterRequestCloseConnectionChannelHandler: ChannelInboundHandler { context.write(self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() - context.eventLoop.scheduleTask(in: .milliseconds(20)) { + context.eventLoop.assumeIsolated().scheduleTask(in: .milliseconds(20)) { context.close(promise: nil) } } } } -class MockConnectionDelegate: HTTP1ConnectionDelegate { - private var lock = NIOLock() +final class MockConnectionDelegate: HTTP1ConnectionDelegate { + private let counts = NIOLockedValueBox(Counts()) - private var _hitConnectionReleased = 0 - private var _hitConnectionClosed = 0 + private struct Counts: Sendable { + var hitConnectionReleased = 0 + var hitConnectionClosed = 0 + } var hitConnectionReleased: Int { - self.lock.withLock { self._hitConnectionReleased } + self.counts.withLockedValue { $0.hitConnectionReleased } } var hitConnectionClosed: Int { - self.lock.withLock { self._hitConnectionClosed } + self.counts.withLockedValue { $0.hitConnectionClosed } } init() {} - func http1ConnectionReleased(_: HTTP1Connection) { - self.lock.withLock { - self._hitConnectionReleased += 1 + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { + self.counts.withLockedValue { + $0.hitConnectionReleased += 1 } } - func http1ConnectionClosed(_: HTTP1Connection) { - self.lock.withLock { - self._hitConnectionClosed += 1 + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { + self.counts.withLockedValue { + $0.hitConnectionClosed += 1 } } } diff --git a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift index b3917173f..d75865da2 100644 --- a/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP1ProxyConnectHandlerTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP1ProxyConnectHandlerTests: XCTestCase { func testProxyConnectWithoutAuthorizationSuccess() { let embedded = EmbeddedChannel() diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift index 545ba1e3c..71f7f3d1a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientRequestHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class HTTP2ClientRequestHandlerTests: XCTestCase { func testResponseBackpressure() { let embedded = EmbeddedChannel() @@ -34,28 +35,36 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertEqual(try embedded.readOutbound(as: HTTPClientRequestPart.self), .end(nil)) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -115,22 +124,30 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 50) var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 100) { writer in - testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 100) { writer in + testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = false @@ -143,12 +160,14 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { testWriter.writabilityChanged(true) embedded.pipeline.fireChannelWritabilityChanged() - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .POST) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - XCTAssertEqual($0.headers.first(name: "content-length"), "100") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .POST) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + XCTAssertEqual($0.headers.first(name: "content-length"), "100") + } + ) // the next body write will be executed once we tick the el. before we make the channel // unwritable @@ -162,9 +181,11 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { embedded.embeddedEventLoop.run() - XCTAssertNoThrow(try embedded.receiveBodyAndVerify { - XCTAssertEqual($0.readableBytes, 2) - }) + XCTAssertNoThrow( + try embedded.receiveBodyAndVerify { + XCTAssertEqual($0.readableBytes, 2) + } + ) XCTAssertEqual(testWriter.written, index + 1) @@ -198,27 +219,35 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -248,27 +277,35 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.write(requestBag, promise: nil) - XCTAssertNoThrow(try embedded.receiveHeadAndVerify { - XCTAssertEqual($0.method, .GET) - XCTAssertEqual($0.uri, "/") - XCTAssertEqual($0.headers.first(name: "host"), "localhost") - }) + XCTAssertNoThrow( + try embedded.receiveHeadAndVerify { + XCTAssertEqual($0.method, .GET) + XCTAssertEqual($0.uri, "/") + XCTAssertEqual($0.headers.first(name: "host"), "localhost") + } + ) XCTAssertNoThrow(try embedded.receiveEnd()) - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) XCTAssertEqual(readEventHandler.readHitCounter, 0) embedded.read() @@ -295,24 +332,32 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in - // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. - embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) - return testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + // Advance time by more than the idle write timeout (that's 1 millisecond) to trigger the timeout. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + return testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = true @@ -335,34 +380,42 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 10) { writer in - embedded.isWritable = false - embedded.pipeline.fireChannelWritabilityChanged() - // This should not trigger any errors or timeouts, because the timer isn't running - // as the channel is not writable. - embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) - - // Now that the channel will become writable, this should trigger a timeout. - embedded.isWritable = true - embedded.pipeline.fireChannelWritabilityChanged() - embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) - - return testWriter.start(writer: writer) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 10) { writer in + embedded.isWritable = false + embedded.pipeline.fireChannelWritabilityChanged() + // This should not trigger any errors or timeouts, because the timer isn't running + // as the channel is not writable. + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(20)) + + // Now that the channel will become writable, this should trigger a timeout. + embedded.isWritable = true + embedded.pipeline.fireChannelWritabilityChanged() + embedded.embeddedEventLoop.advanceTime(by: .milliseconds(2)) + + return testWriter.start(writer: writer) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = true @@ -385,22 +438,30 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let testWriter = TestBackpressureWriter(eventLoop: embedded.eventLoop, parts: 5) var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost/", method: .POST, body: .stream(length: 2) { writer in - return testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) - })) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost/", + method: .POST, + body: .stream(contentLength: 2) { writer in + testWriter.start(writer: writer, expectedErrors: [HTTPClientError.cancelled]) + } + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to be able to create a request") } let delegate = ResponseBackpressureDelegate(eventLoop: embedded.eventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleWriteTimeout: .milliseconds(1)), + delegate: delegate + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } embedded.isWritable = true @@ -451,16 +512,20 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let delegate = ResponseAccumulator(request: request) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embedded.eventLoop), - task: .init(eventLoop: embedded.eventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), - delegate: delegate - )) - guard let requestBag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag") } + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embedded.eventLoop), + task: .init(eventLoop: embedded.eventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(idleReadTimeout: .milliseconds(200)), + delegate: delegate + ) + ) + guard let requestBag = maybeRequestBag else { + return XCTFail("Expected to be able to create a request bag") + } embedded.isWritable = false XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) @@ -494,16 +559,20 @@ class HTTP2ClientRequestHandlerTests: XCTestCase { let handler = HTTP2ClientRequestHandler( eventLoop: eventLoop ) - let channel = EmbeddedChannel(handlers: [ - ChangeWritabilityOnFlush(), - handler, - ], loop: eventLoop) + let channel = EmbeddedChannel( + handlers: [ + ChangeWritabilityOnFlush(), + handler, + ], + loop: eventLoop + ) try channel.connect(to: .init(ipAddress: "127.0.0.1", port: 80)).wait() - let request = MockHTTPExecutableRequest() // non empty body is important to trigger this bug as we otherwise finish the request in a single flush - request.requestFramingMetadata.body = .fixedSize(1) - request.raiseErrorIfUnimplementedMethodIsCalled = false + let request = MockHTTPExecutableRequest( + framingMetadata: RequestFramingMetadata(connectionClose: false, body: .fixedSize(1)), + raiseErrorIfUnimplementedMethodIsCalled: false + ) channel.writeAndFlush(request, promise: nil) XCTAssertEqual(request.events.map(\.kind), [.willExecuteRequest, .requestHeadSent]) } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift index 97f0385ea..183a227bd 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift @@ -12,17 +12,21 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that really need @testable go into HTTP2ClientInternalTests.swift -#if canImport(Network) -import Network -#endif +import AsyncHTTPClient // NOT @testable - tests that really need @testable go into HTTP2ClientInternalTests.swift import Logging +import NIOConcurrencyHelpers import NIOCore +import NIOFoundationCompat import NIOHTTP1 +import NIOHTTP2 import NIOPosix import NIOSSL import XCTest +#if canImport(Network) +import Network +#endif + class HTTP2ClientTests: XCTestCase { func makeDefaultHTTPClient( eventLoopGroupProvider: HTTPClient.EventLoopGroupProvider = .singleton @@ -68,7 +72,7 @@ class HTTP2ClientTests: XCTestCase { let client = self.makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } var response: HTTPClient.Response? - let body = HTTPClient.Body.stream(length: nil) { writer in + let body = HTTPClient.Body.stream(contentLength: nil) { writer in writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))).flatMap { writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))) } @@ -84,7 +88,7 @@ class HTTP2ClientTests: XCTestCase { defer { XCTAssertNoThrow(try bin.shutdown()) } let client = self.makeDefaultHTTPClient() defer { XCTAssertNoThrow(try client.syncShutdown()) } - let body = HTTPClient.Body.stream(length: 12) { writer in + let body = HTTPClient.Body.stream(contentLength: 12) { writer in writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))).flatMap { writer.write(.byteBuffer(ByteBuffer(integer: UInt64(0)))) } @@ -132,8 +136,8 @@ class HTTP2ClientTests: XCTestCase { let q = DispatchQueue(label: "worker \(w)") q.async(group: allDone) { func go() { - allWorkersReady.signal() // tell the driver we're ready - allWorkersGo.wait() // wait for the driver to let us go + allWorkersReady.signal() // tell the driver we're ready + allWorkersGo.wait() // wait for the driver to let us go for _ in 0..] = [] - XCTAssertNoThrow(results = try EventLoopFuture - .whenAllComplete(responses, on: clientGroup.next()) - .timeout(after: .seconds(2)) - .wait()) + XCTAssertNoThrow( + results = + try EventLoopFuture + .whenAllComplete(responses, on: clientGroup.next()) + .timeout(after: .seconds(2)) + .wait() + ) for result in results { switch result { @@ -276,15 +284,16 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(bin.port)")) guard let request = maybeRequest else { return } - var task: HTTPClient.Task! + let taskBox = NIOLockedValueBox?>(nil) let delegate = HeadReceivedCallback { _ in // request is definitely running because we just received a head from the server - task.cancel() + taskBox.withLockedValue { $0 }!.cancel() } - task = client.execute( + let task = client.execute( request: request, delegate: delegate ) + taskBox.withLockedValue { $0 = task } XCTAssertThrowsError(try task.futureResult.timeout(after: .seconds(2)).wait()) { XCTAssertEqualTypeAndValue($0, HTTPClientError.cancelled) @@ -353,18 +362,20 @@ class HTTP2ClientTests: XCTestCase { guard let request = maybeRequest else { return } let tasks = (0..<100).map { _ -> HTTPClient.Task in - var task: HTTPClient.Task! + let taskBox = NIOLockedValueBox?>(nil) + let delegate = HeadReceivedCallback { _ in // request is definitely running because we just received a head from the server cancelPool.next().execute { // canceling from a different thread - task.cancel() + taskBox.withLockedValue { $0 }!.cancel() } } - task = client.execute( + let task = client.execute( request: request, delegate: delegate ) + taskBox.withLockedValue { $0 = task } return task } @@ -397,7 +408,11 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest1 = try HTTPClient.Request(url: "https://localhost:\(bin.port)/get")) guard let request1 = maybeRequest1 else { return } - let task1 = client.execute(request: request1, delegate: ResponseAccumulator(request: request1), eventLoop: .delegateAndChannel(on: el1)) + let task1 = client.execute( + request: request1, + delegate: ResponseAccumulator(request: request1), + eventLoop: .delegateAndChannel(on: el1) + ) var response1: ResponseAccumulator.Response? XCTAssertNoThrow(response1 = try task1.wait()) @@ -408,15 +423,17 @@ class HTTP2ClientTests: XCTestCase { let serverGroup = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try serverGroup.syncShutdownGracefully()) } var maybeServer: Channel? - XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: serverGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: 1) - .childChannelInitializer { channel in - channel.close() - } - .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .bind(host: "127.0.0.1", port: serverPort) - .wait()) + XCTAssertNoThrow( + maybeServer = try ServerBootstrap(group: serverGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: 1) + .childChannelInitializer { channel in + channel.close() + } + .childChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .bind(host: "127.0.0.1", port: serverPort) + .wait() + ) // shutting down the old server closes all connections immediately XCTAssertNoThrow(try bin.shutdown()) // client is now in HTTP/2 state and the HTTPBin is closed @@ -427,7 +444,11 @@ class HTTP2ClientTests: XCTestCase { XCTAssertNoThrow(maybeRequest2 = try HTTPClient.Request(url: "https://localhost:\(serverPort)/")) guard let request2 = maybeRequest2 else { return } - let task2 = client.execute(request: request2, delegate: ResponseAccumulator(request: request2), eventLoop: .delegateAndChannel(on: el2)) + let task2 = client.execute( + request: request2, + delegate: ResponseAccumulator(request: request2), + eventLoop: .delegateAndChannel(on: el2) + ) XCTAssertThrowsError(try task2.wait()) { error in XCTAssertNil( error as? HTTPClientError, @@ -448,12 +469,90 @@ class HTTP2ClientTests: XCTestCase { XCTAssertEqual(response?.version, .http2) XCTAssertEqual(response?.body?.readableBytes, 10_000) } + + func testSimplePost() { + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + let client = self.makeDefaultHTTPClient() + defer { XCTAssertNoThrow(try client.syncShutdown()) } + var response: HTTPClient.Response? + XCTAssertNoThrow( + response = try client.post( + url: "https://localhost:\(bin.port)/post", + body: .byteBuffer(ByteBuffer(repeating: 0, count: 12345)) + ).wait() + ) + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual( + String(buffer: ByteBuffer(repeating: 0, count: 12345)), + try response?.body.map { body in + try JSONDecoder().decode(RequestInfo.self, from: body) + }?.data + ) + } + + func testHugePost() { + // Regression test for https://github.com/swift-server/async-http-client/issues/784 + let group = MultiThreadedEventLoopGroup(numberOfThreads: 2) // This needs to be more than 1! + defer { + XCTAssertNoThrow(try group.syncShutdownGracefully()) + } + var serverH2Settings: HTTP2Settings = HTTP2Settings() + serverH2Settings.append(HTTP2Setting(parameter: .maxFrameSize, value: 16 * 1024 * 1024 - 1)) + serverH2Settings.append(HTTP2Setting(parameter: .initialWindowSize, value: Int(Int32.max))) + let bin = HTTPBin( + .http2(compress: false, settings: serverH2Settings) + ) + defer { XCTAssertNoThrow(try bin.shutdown()) } + var clientConfig = HTTPClient.Configuration() + clientConfig.tlsConfiguration = .clientDefault + clientConfig.tlsConfiguration?.certificateVerification = .none + clientConfig.httpVersion = .automatic + let client = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: clientConfig, + backgroundActivityLogger: Logger(label: "HTTPClient", factory: StreamLogHandler.standardOutput(label:)) + ) + defer { XCTAssertNoThrow(try client.syncShutdown()) } + + let loop1 = group.next() + let loop2 = group.next() + precondition(loop1 !== loop2, "bug in test setup, need two distinct loops") + + XCTAssertNoThrow( + try client.execute( + request: .init(url: "https://localhost:\(bin.port)/get"), + eventLoop: .delegateAndChannel(on: loop1) // This will force the channel to live on `loop1`. + ).wait() + ) + var response: HTTPClient.Response? + let byteCount = 1024 * 1024 * 1024 // 1 GiB (unfortunately it has to be that big to trigger the bug) + XCTAssertNoThrow( + response = try client.execute( + request: HTTPClient.Request( + url: "https://localhost:\(bin.port)/post-respond-with-byte-count", + method: .POST, + body: .data(Data(repeating: 0, count: byteCount)) + ), + eventLoop: .delegate(on: loop2) + ).wait() + ) + XCTAssertEqual(.ok, response?.status) + XCTAssertEqual(response?.version, .http2) + XCTAssertEqual( + "\(byteCount)", + try response?.body.map { body in + try JSONDecoder().decode(RequestInfo.self, from: body) + }?.data + ) + } } private final class HeadReceivedCallback: HTTPClientResponseDelegate { typealias Response = Void - private let didReceiveHeadCallback: (HTTPResponseHead) -> Void - init(didReceiveHead: @escaping (HTTPResponseHead) -> Void) { + private let didReceiveHeadCallback: @Sendable (HTTPResponseHead) -> Void + init(didReceiveHead: @escaping @Sendable (HTTPResponseHead) -> Void) { self.didReceiveHeadCallback = didReceiveHead } @@ -474,11 +573,17 @@ private final class SendHeaderAndWaitChannelHandler: ChannelInboundHandler { let requestPart = self.unwrapInboundIn(data) switch requestPart { case .head: - context.writeAndFlush(self.wrapOutboundOut(.head(HTTPResponseHead( - version: HTTPVersion(major: 1, minor: 1), - status: .ok - )) - ), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut( + .head( + HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok + ) + ) + ), + promise: nil + ) case .body, .end: return } diff --git a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift index 15e5cdff2..3244e2b5a 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2ConnectionTests.swift @@ -12,17 +12,20 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOHPACK import NIOHTTP1 +import NIOHTTP2 import NIOPosix import NIOSSL import NIOTestUtils import XCTest +@testable import AsyncHTTPClient + class HTTP2ConnectionTests: XCTestCase { func testCreateNewConnectionFailureClosedIO() { let embedded = EmbeddedChannel() @@ -33,14 +36,16 @@ class HTTP2ConnectionTests: XCTestCase { embedded.embeddedEventLoop.run() let logger = Logger(label: "test.http2.connection") - XCTAssertThrowsError(try HTTP2Connection.start( - channel: embedded, - connectionID: 0, - delegate: TestHTTP2ConnectionDelegate(), - decompression: .disabled, - maximumConnectionUses: nil, - logger: logger - ).wait()) + XCTAssertThrowsError( + try HTTP2Connection.start( + channel: embedded, + connectionID: 0, + delegate: TestHTTP2ConnectionDelegate(), + decompression: .disabled, + maximumConnectionUses: nil, + logger: logger + ).map { _ in }.nonisolated().wait() + ) } func testConnectionToleratesShutdownEventsAfterAlreadyClosed() { @@ -65,7 +70,7 @@ class HTTP2ConnectionTests: XCTestCase { XCTAssertThrowsError(try startFuture.wait()) // should not crash - connection.shutdown() + connection.sendableView.shutdown() } func testSimpleGetRequest() { @@ -78,12 +83,13 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - ) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") @@ -92,15 +98,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -134,12 +142,14 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -153,15 +163,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -198,12 +210,13 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - ) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") @@ -215,15 +228,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -262,7 +277,7 @@ class HTTP2ConnectionTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.dataArrivedPromise.succeed(()) - self.triggerResponseFuture.hop(to: context.eventLoop).whenSuccess { + self.triggerResponseFuture.hop(to: context.eventLoop).assumeIsolated().whenSuccess { switch self.unwrapInboundIn(data) { case .head: context.write(self.wrapOutboundOut(.head(.init(version: .http2, status: .ok))), promise: nil) @@ -290,12 +305,14 @@ class HTTP2ConnectionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -303,15 +320,17 @@ class HTTP2ConnectionTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to have a request bag at this point") } @@ -320,7 +339,9 @@ class HTTP2ConnectionTests: XCTestCase { XCTAssertNoThrow(try serverReceivedRequestPromise.futureResult.wait()) var channelCount: Int? - XCTAssertNoThrow(channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait()) + XCTAssertNoThrow( + channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait() + ) XCTAssertEqual(channelCount, 1) triggerResponsePromise.succeed(()) @@ -330,7 +351,9 @@ class HTTP2ConnectionTests: XCTestCase { var retryCount = 0 let maxRetries = 1000 while retryCount < maxRetries { - XCTAssertNoThrow(channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait()) + XCTAssertNoThrow( + channelCount = try eventLoop.submit { http2Connection.__forTesting_getStreamChannels().count }.wait() + ) if channelCount == 0 { break } @@ -338,9 +361,31 @@ class HTTP2ConnectionTests: XCTestCase { } XCTAssertLessThan(retryCount, maxRetries) } + + func testServerPushIsDisabled() { + let embedded = EmbeddedChannel() + let logger = Logger(label: "test.http2.connection") + let connection = HTTP2Connection( + channel: embedded, + connectionID: 0, + decompression: .disabled, + maximumConnectionUses: nil, + delegate: TestHTTP2ConnectionDelegate(), + logger: logger + ) + _ = connection._start0() + + let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) + XCTAssertNoThrow(try connection.channel.writeAndFlush(settingsFrame).wait()) + + let pushPromiseFrame = HTTP2Frame(streamID: 0, payload: .pushPromise(.init(pushedStreamID: 1, headers: [:]))) + XCTAssertThrowsError(try connection.channel.writeAndFlush(pushPromiseFrame).wait()) { error in + XCTAssertNotNil(error as? NIOHTTP2Errors.PushInViolationOfSetting) + } + } } -class TestConnectionCreator { +final class TestConnectionCreator { enum Error: Swift.Error { case alreadyCreatingAnotherConnection case wantedHTTP2ConnectionButGotHTTP1 @@ -349,12 +394,11 @@ class TestConnectionCreator { enum State { case idle - case waitingForHTTP1Connection(EventLoopPromise) - case waitingForHTTP2Connection(EventLoopPromise) + case waitingForHTTP1Connection(EventLoopPromise) + case waitingForHTTP2Connection(EventLoopPromise) } - private var state: State = .idle - private let lock = NIOLock() + private let lock = NIOLockedValueBox(.idle) init() {} @@ -364,7 +408,7 @@ class TestConnectionCreator { connectionID: HTTPConnectionPool.Connection.ID = 0, on eventLoop: EventLoop, logger: Logger = .init(label: "test") - ) throws -> HTTP1Connection { + ) throws -> HTTP1Connection.SendableView { let request = try! HTTPClient.Request(url: "https://localhost:\(port)") var tlsConfiguration = TLSConfiguration.makeClientConfiguration() @@ -378,13 +422,13 @@ class TestConnectionCreator { sslContextCache: .init() ) - let promise = try self.lock.withLock { () -> EventLoopPromise in - guard case .idle = self.state else { + let promise = try self.lock.withLockedValue { state in + guard case .idle = state else { throw Error.alreadyCreatingAnotherConnection } - let promise = eventLoop.makePromise(of: HTTP1Connection.self) - self.state = .waitingForHTTP1Connection(promise) + let promise = eventLoop.makePromise(of: HTTP1Connection.SendableView.self) + state = .waitingForHTTP1Connection(promise) return promise } @@ -407,7 +451,7 @@ class TestConnectionCreator { connectionID: HTTPConnectionPool.Connection.ID = 0, on eventLoop: EventLoop, logger: Logger = .init(label: "test") - ) throws -> HTTP2Connection { + ) throws -> HTTP2Connection.SendableView { let request = try! HTTPClient.Request(url: "https://localhost:\(port)") var tlsConfiguration = TLSConfiguration.makeClientConfiguration() @@ -421,13 +465,13 @@ class TestConnectionCreator { sslContextCache: .init() ) - let promise = try self.lock.withLock { () -> EventLoopPromise in - guard case .idle = self.state else { + let promise = try self.lock.withLockedValue { state in + guard case .idle = state else { throw Error.alreadyCreatingAnotherConnection } - let promise = eventLoop.makePromise(of: HTTP2Connection.self) - self.state = .waitingForHTTP2Connection(promise) + let promise = eventLoop.makePromise(of: HTTP2Connection.SendableView.self) + state = .waitingForHTTP2Connection(promise) return promise } @@ -446,7 +490,7 @@ class TestConnectionCreator { } extension TestConnectionCreator: HTTPConnectionRequester { - enum EitherPromiseWrapper { + enum EitherPromiseWrapper: Sendable { case succeed(EventLoopPromise, SucceedType) case fail(EventLoopPromise, Error) @@ -460,37 +504,38 @@ extension TestConnectionCreator: HTTPConnectionRequester { } } - func http1ConnectionCreated(_ connection: HTTP1Connection) { - let wrapper = self.lock.withLock { () -> (EitherPromiseWrapper) in + func http1ConnectionCreated(_ connection: HTTP1Connection.SendableView) { + let wrapper: EitherPromiseWrapper = self.lock + .withLockedValue { state in - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .succeed(promise, connection) + switch state { + case .waitingForHTTP1Connection(let promise): + return .succeed(promise, connection) - case .waitingForHTTP2Connection(let promise): - return .fail(promise, Error.wantedHTTP2ConnectionButGotHTTP1) + case .waitingForHTTP2Connection(let promise): + return .fail(promise, Error.wantedHTTP2ConnectionButGotHTTP1) - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.complete() } - func http2ConnectionCreated(_ connection: HTTP2Connection, maximumStreams: Int) { - let wrapper = self.lock.withLock { () -> (EitherPromiseWrapper) in + func http2ConnectionCreated(_ connection: HTTP2Connection.SendableView, maximumStreams: Int) { + let wrapper: EitherPromiseWrapper = self.lock + .withLockedValue { state in + switch state { + case .waitingForHTTP1Connection(let promise): + return .fail(promise, Error.wantedHTTP1ConnectionButGotHTTP2) - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .fail(promise, Error.wantedHTTP1ConnectionButGotHTTP2) + case .waitingForHTTP2Connection(let promise): + return .succeed(promise, connection) - case .waitingForHTTP2Connection(let promise): - return .succeed(promise, connection) - - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.complete() } @@ -509,19 +554,20 @@ extension TestConnectionCreator: HTTPConnectionRequester { } func failedToCreateHTTPConnection(_: HTTPConnectionPool.Connection.ID, error: Swift.Error) { - let wrapper = self.lock.withLock { () -> (FailPromiseWrapper) in + let wrapper: FailPromiseWrapper = self.lock + .withLockedValue { state in - switch self.state { - case .waitingForHTTP1Connection(let promise): - return .type1(promise) + switch state { + case .waitingForHTTP1Connection(let promise): + return .type1(promise) - case .waitingForHTTP2Connection(let promise): - return .type2(promise) + case .waitingForHTTP2Connection(let promise): + return .type2(promise) - case .idle: - preconditionFailure("Invalid state: \(self.state)") + case .idle: + preconditionFailure("Invalid state: \(state)") + } } - } wrapper.fail(error) } @@ -530,76 +576,78 @@ extension TestConnectionCreator: HTTPConnectionRequester { } } -class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { +final class TestHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { var hitStreamClosed: Int { - self.lock.withLock { self._hitStreamClosed } + self.lock.withLockedValue { $0.hitStreamClosed } } var hitGoAwayReceived: Int { - self.lock.withLock { self._hitGoAwayReceived } + self.lock.withLockedValue { $0.hitGoAwayReceived } } var hitConnectionClosed: Int { - self.lock.withLock { self._hitConnectionClosed } + self.lock.withLockedValue { $0.hitConnectionClosed } } var maxStreamSetting: Int { - self.lock.withLock { self._maxStreamSetting } + self.lock.withLockedValue { $0.maxStreamSetting } } - private let lock = NIOLock() - private var _hitStreamClosed: Int = 0 - private var _hitGoAwayReceived: Int = 0 - private var _hitConnectionClosed: Int = 0 - private var _maxStreamSetting: Int = 100 + private let lock = NIOLockedValueBox(.init()) + private struct Counts { + var hitStreamClosed: Int = 0 + var hitGoAwayReceived: Int = 0 + var hitConnectionClosed: Int = 0 + var maxStreamSetting: Int = 100 + } init() {} - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) {} + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) {} - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { - self.lock.withLock { - self._hitStreamClosed += 1 + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) { + self.lock.withLockedValue { + $0.hitStreamClosed += 1 } } - func http2ConnectionGoAwayReceived(_: HTTP2Connection) { - self.lock.withLock { - self._hitGoAwayReceived += 1 + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) { + self.lock.withLockedValue { + $0.hitGoAwayReceived += 1 } } - func http2ConnectionClosed(_: HTTP2Connection) { - self.lock.withLock { - self._hitConnectionClosed += 1 + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { + self.lock.withLockedValue { + $0.hitConnectionClosed += 1 } } } final class EmptyHTTP2ConnectionDelegate: HTTP2ConnectionDelegate { - func http2Connection(_: HTTP2Connection, newMaxStreamSetting: Int) { + func http2Connection(_: HTTPConnectionPool.Connection.ID, newMaxStreamSetting: Int) { preconditionFailure("Unimplemented") } - func http2ConnectionStreamClosed(_: HTTP2Connection, availableStreams: Int) { + func http2ConnectionStreamClosed(_: HTTPConnectionPool.Connection.ID, availableStreams: Int) { preconditionFailure("Unimplemented") } - func http2ConnectionGoAwayReceived(_: HTTP2Connection) { + func http2ConnectionGoAwayReceived(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } - func http2ConnectionClosed(_: HTTP2Connection) { + func http2ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } } final class EmptyHTTP1ConnectionDelegate: HTTP1ConnectionDelegate { - func http1ConnectionReleased(_: HTTP1Connection) { + func http1ConnectionReleased(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } - func http1ConnectionClosed(_: HTTP1Connection) { + func http1ConnectionClosed(_: HTTPConnectionPool.Connection.ID) { preconditionFailure("Unimplemented") } } diff --git a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift index 611e31457..f2b56daa0 100644 --- a/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTP2IdleHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded import NIOHTTP2 import XCTest +@testable import AsyncHTTPClient + class HTTP2IdleHandlerTests: XCTestCase { func testReceiveSettingsWithMaxConcurrentStreamSetting() { let delegate = MockHTTP2IdleHandlerDelegate() @@ -26,7 +27,10 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) @@ -41,7 +45,11 @@ class HTTP2IdleHandlerTests: XCTestCase { let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([]))) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) - XCTAssertEqual(delegate.maxStreams, 100, "Expected to assume 100 maxConcurrentConnection, if no setting was present") + XCTAssertEqual( + delegate.maxStreams, + 100, + "Expected to assume 100 maxConcurrentConnection, if no setting was present" + ) } func testEmptySettingsDontOverwriteMaxConcurrentStreamSetting() { @@ -50,7 +58,10 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) @@ -66,12 +77,18 @@ class HTTP2IdleHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [idleHandler]) XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) - let settingsFrame = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)]))) + let settingsFrame = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 10)])) + ) XCTAssertEqual(delegate.maxStreams, nil) XCTAssertNoThrow(try embedded.writeInbound(settingsFrame)) XCTAssertEqual(delegate.maxStreams, 10) - let emptySettings = HTTP2Frame(streamID: 0, payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 20)]))) + let emptySettings = HTTP2Frame( + streamID: 0, + payload: .settings(.settings([.init(parameter: .maxConcurrentStreams, value: 20)])) + ) XCTAssertNoThrow(try embedded.writeInbound(emptySettings)) XCTAssertEqual(delegate.maxStreams, 20) } @@ -83,7 +100,10 @@ class HTTP2IdleHandlerTests: XCTestCase { XCTAssertNoThrow(try embedded.connect(to: .makeAddressResolvingHost("localhost", port: 0)).wait()) let randomStreamID = HTTP2StreamID((0.. + + var receivedMessages: [Message] { + get { + self.messages.value.received + } + set { + self.messages.value.received = newValue + } + } + var sentMessages: [Message] { + get { + self.messages.value.sent + } + set { + self.messages.value.sent = newValue + } + } private let eventLoop: EventLoop private let randoEL: EventLoop init(expectedEventLoop: EventLoop, randomOtherEventLoop: EventLoop) { self.eventLoop = expectedEventLoop self.randoEL = randomOtherEventLoop + self.messages = .makeBoxSendingValue(Messages(), eventLoop: expectedEventLoop) } func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequestHead(head)) } func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequestPart(part)) } func didSendRequest(task: HTTPClient.Task) { - self.eventLoop.assertInEventLoop() self.sentMessages.append(.sentRequest) } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.eventLoop.assertInEventLoop() self.receivedMessages.append(.error(error)) } - public func didReceiveHead(task: HTTPClient.Task, - _ head: HTTPResponseHead) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() + public func didReceiveHead( + task: HTTPClient.Task, + _ head: HTTPResponseHead + ) -> EventLoopFuture { self.receivedMessages.append(.head(head)) return self.randoEL.makeSucceededFuture(()) } - func didReceiveBodyPart(task: HTTPClient.Task, - _ buffer: ByteBuffer) -> EventLoopFuture { - self.eventLoop.assertInEventLoop() + func didReceiveBodyPart( + task: HTTPClient.Task, + _ buffer: ByteBuffer + ) -> EventLoopFuture { self.receivedMessages.append(.bodyPart(buffer)) return self.randoEL.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Response { - self.eventLoop.assertInEventLoop() - return (self.receivedMessages, self.sentMessages) + (self.receivedMessages, self.sentMessages) } } @@ -223,7 +263,7 @@ class HTTPClientInternalTests: XCTestCase { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) } - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { let buffer = ByteBuffer(string: "4321") @@ -231,22 +271,38 @@ class HTTPClientInternalTests: XCTestCase { } } - let request = try Request(url: "http://127.0.0.1:\(server.serverPort)/custom", - body: body) + let request = try Request( + url: "http://127.0.0.1:\(server.serverPort)/custom", + body: body + ) let delegate = Delegate(expectedEventLoop: delegateEL, randomOtherEventLoop: randoEL) - let future = httpClient.execute(request: request, - delegate: delegate, - eventLoop: .init(.testOnly_exact(channelOn: channelEL, - delegateOn: delegateEL))).futureResult - - XCTAssertNoThrow(try server.readInbound()) // .head - XCTAssertNoThrow(try server.readInbound()) // .body - XCTAssertNoThrow(try server.readInbound()) // .end + let future = httpClient.execute( + request: request, + delegate: delegate, + eventLoop: .init( + .testOnly_exact( + channelOn: channelEL, + delegateOn: delegateEL + ) + ) + ).futureResult + + XCTAssertNoThrow(try server.readInbound()) // .head + XCTAssertNoThrow(try server.readInbound()) // .body + XCTAssertNoThrow(try server.readInbound()) // .end // Send 3 parts, but only one should be received until the future is complete - XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), - status: .ok, - headers: HTTPHeaders([("Transfer-Encoding", "chunked")]))))) + XCTAssertNoThrow( + try server.writeOutbound( + .head( + .init( + version: .init(major: 1, minor: 1), + status: .ok, + headers: HTTPHeaders([("Transfer-Encoding", "chunked")]) + ) + ) + ) + ) let buffer = ByteBuffer(string: "1234") XCTAssertNoThrow(try server.writeOutbound(.body(.byteBuffer(buffer)))) XCTAssertNoThrow(try server.writeOutbound(.end(nil))) @@ -278,7 +334,7 @@ class HTTPClientInternalTests: XCTestCase { switch sentMessages.dropFirst(3).first { case .some(.sentRequest): - () // OK + () // OK default: XCTFail("wrong message") } @@ -316,7 +372,10 @@ class HTTPClientInternalTests: XCTestCase { let el = group.next() let req1 = client.execute(request: request, eventLoop: .delegate(on: el)) let req2 = client.execute(request: request, eventLoop: .delegateAndChannel(on: el)) - let req3 = client.execute(request: request, eventLoop: .init(.testOnly_exact(channelOn: el, delegateOn: el))) + let req3 = client.execute( + request: request, + eventLoop: .init(.testOnly_exact(channelOn: el, delegateOn: el)) + ) XCTAssert(req1.eventLoop === el) XCTAssert(req2.eventLoop === el) XCTAssert(req3.eventLoop === el) @@ -335,8 +394,8 @@ class HTTPClientInternalTests: XCTestCase { _ = httpClient.get(url: "http://localhost:\(server.serverPort)/wait") - XCTAssertNoThrow(try server.readInbound()) // .head - XCTAssertNoThrow(try server.readInbound()) // .end + XCTAssertNoThrow(try server.readInbound()) // .head + XCTAssertNoThrow(try server.readInbound()) // .end do { try httpClient.syncShutdown(requiresCleanClose: true) @@ -366,7 +425,7 @@ class HTTPClientInternalTests: XCTestCase { let el2 = group.next() XCTAssert(el1 !== el2) - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in XCTAssert(el1.inEventLoop) let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { @@ -376,10 +435,16 @@ class HTTPClientInternalTests: XCTestCase { } } let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/post", method: .POST, body: body) - let response = httpClient.execute(request: request, - delegate: ResponseAccumulator(request: request), - eventLoop: HTTPClient.EventLoopPreference(.testOnly_exact(channelOn: el2, - delegateOn: el1))) + let response = httpClient.execute( + request: request, + delegate: ResponseAccumulator(request: request), + eventLoop: HTTPClient.EventLoopPreference( + .testOnly_exact( + channelOn: el2, + delegateOn: el1 + ) + ) + ) XCTAssert(el1 === response.eventLoop) XCTAssertNoThrow(try response.wait()) } @@ -400,17 +465,25 @@ class HTTPClientInternalTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)//get") let delegate = ResponseAccumulator(request: request) - let task = client.execute(request: request, delegate: delegate, eventLoop: .init(.testOnly_exact(channelOn: el1, delegateOn: el2))) + let task = client.execute( + request: request, + delegate: delegate, + eventLoop: .init(.testOnly_exact(channelOn: el1, delegateOn: el2)) + ) XCTAssertTrue(task.futureResult.eventLoop === el2) XCTAssertNoThrow(try task.wait()) } func testConnectErrorCalloutOnCorrectEL() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void let expectedEL: EventLoop - var receivedError: Bool = false + let _receivedError = NIOLockedValueBox(false) + + var receivedError: Bool { + self._receivedError.withLockedValue { $0 } + } init(expectedEL: EventLoop) { self.expectedEL = expectedEL @@ -419,7 +492,7 @@ class HTTPClientInternalTests: XCTestCase { func didFinishRequest(task: HTTPClient.Task) throws {} func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.receivedError = true + self._receivedError.withLockedValue { $0 = true } XCTAssertTrue(self.expectedEL.inEventLoop) } } @@ -441,7 +514,11 @@ class HTTPClientInternalTests: XCTestCase { let request = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/get") let delegate = TestDelegate(expectedEL: el1) XCTAssertNoThrow(try httpBin.shutdown()) - let task = client.execute(request: request, delegate: delegate, eventLoop: .init(.testOnly_exact(channelOn: el2, delegateOn: el1))) + let task = client.execute( + request: request, + delegate: delegate, + eventLoop: .init(.testOnly_exact(channelOn: el2, delegateOn: el1)) + ) XCTAssertThrowsError(try task.wait()) XCTAssertTrue(delegate.receivedError) } @@ -474,10 +551,13 @@ class HTTPClientInternalTests: XCTestCase { let request6 = try Request(url: "https://127.0.0.1") XCTAssertEqual(request6.deconstructedURL.scheme, .https) - XCTAssertEqual(request6.deconstructedURL.connectionTarget, .ipAddress( - serialization: "127.0.0.1", - address: try! SocketAddress(ipAddress: "127.0.0.1", port: 443) - )) + XCTAssertEqual( + request6.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "127.0.0.1", + address: try! SocketAddress(ipAddress: "127.0.0.1", port: 443) + ) + ) XCTAssertEqual(request6.deconstructedURL.uri, "/") let request7 = try Request(url: "https://0x7F.1:9999") @@ -487,18 +567,24 @@ class HTTPClientInternalTests: XCTestCase { let request8 = try Request(url: "http://[::1]") XCTAssertEqual(request8.deconstructedURL.scheme, .http) - XCTAssertEqual(request8.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::1]", - address: try! SocketAddress(ipAddress: "::1", port: 80) - )) + XCTAssertEqual( + request8.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::1]", + address: try! SocketAddress(ipAddress: "::1", port: 80) + ) + ) XCTAssertEqual(request8.deconstructedURL.uri, "/") let request9 = try Request(url: "http://[763e:61d9::6ACA:3100:6274]:4242/foo/bar?baz") XCTAssertEqual(request9.deconstructedURL.scheme, .http) - XCTAssertEqual(request9.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[763e:61d9::6ACA:3100:6274]", - address: try! SocketAddress(ipAddress: "763e:61d9::6aca:3100:6274", port: 4242) - )) + XCTAssertEqual( + request9.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[763e:61d9::6ACA:3100:6274]", + address: try! SocketAddress(ipAddress: "763e:61d9::6aca:3100:6274", port: 4242) + ) + ) XCTAssertEqual(request9.deconstructedURL.uri, "/foo/bar?baz") // Some systems have quirks in their implementations of 'ntop' which cause them to write @@ -507,18 +593,24 @@ class HTTPClientInternalTests: XCTestCase { // so the serialization must be kept verbatim as it was given in the request. let request10 = try Request(url: "http://[::c0a8:1]:4242/foo/bar?baz") XCTAssertEqual(request10.deconstructedURL.scheme, .http) - XCTAssertEqual(request10.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::c0a8:1]", - address: try! SocketAddress(ipAddress: "::c0a8:1", port: 4242) - )) + XCTAssertEqual( + request10.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::c0a8:1]", + address: try! SocketAddress(ipAddress: "::c0a8:1", port: 4242) + ) + ) XCTAssertEqual(request10.deconstructedURL.uri, "/foo/bar?baz") let request11 = try Request(url: "http://[::192.168.0.1]:4242/foo/bar?baz") XCTAssertEqual(request11.deconstructedURL.scheme, .http) - XCTAssertEqual(request11.deconstructedURL.connectionTarget, .ipAddress( - serialization: "[::192.168.0.1]", - address: try! SocketAddress(ipAddress: "::192.168.0.1", port: 4242) - )) + XCTAssertEqual( + request11.deconstructedURL.connectionTarget, + .ipAddress( + serialization: "[::192.168.0.1]", + address: try! SocketAddress(ipAddress: "::192.168.0.1", port: 4242) + ) + ) XCTAssertEqual(request11.deconstructedURL.uri, "/foo/bar?baz") } @@ -547,7 +639,7 @@ class HTTPClientInternalTests: XCTestCase { } // Empty collection. do { - let elements: Array = [] + let elements: [Int] = [] XCTAssertTrue(elements.hasSuffix([])) XCTAssertFalse(elements.hasSuffix([0])) XCTAssertFalse(elements.hasSuffix([42])) @@ -585,7 +677,8 @@ class HTTPClientInternalTests: XCTestCase { ).futureResult } _ = try EventLoopFuture.whenAllSucceed(resultFutures, on: self.clientGroup.next()).wait() - let threadPools = delegates.map { $0.fileIOThreadPool } + + let threadPools = delegates.map { $0._fileIOThreadPool } let firstThreadPool = threadPools.first ?? nil XCTAssert(threadPools.dropFirst().allSatisfy { $0 === firstThreadPool }) } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift index be03f6a6a..4c2d24dc4 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientNIOTSTests.swift @@ -12,10 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient -#if canImport(Network) -import Network -#endif import NIOConcurrencyHelpers import NIOCore import NIOPosix @@ -23,6 +19,12 @@ import NIOSSL import NIOTransportServices import XCTest +@testable import AsyncHTTPClient + +#if canImport(Network) +import Network +#endif + class HTTPClientNIOTSTests: XCTestCase { var clientGroup: EventLoopGroup! @@ -55,11 +57,12 @@ class HTTPClientNIOTSTests: XCTestCase { guard isTestingNIOTS() else { return } let httpBin = HTTPBin(.http1_1(ssl: true)) - var config = HTTPClient.Configuration() - config.networkFrameworkWaitForConnectivity = false - config.connectionPool.retryConnectionEstablishment = false - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: config) + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) XCTAssertNoThrow(try httpBin.shutdown()) @@ -70,8 +73,10 @@ class HTTPClientNIOTSTests: XCTestCase { _ = try httpClient.get(url: "https://localhost:\(httpBin.port)/get").wait() XCTFail("This should have failed") } catch let error as HTTPClient.NWTLSError { - XCTAssert(error.status == errSSLHandshakeFail || error.status == errSSLBadCert, - "unexpected NWTLSError with status \(error.status)") + XCTAssert( + error.status == errSSLHandshakeFail || error.status == errSSLBadCert, + "unexpected NWTLSError with status \(error.status)" + ) } catch { XCTFail("Error should have been NWTLSError not \(type(of: error))") } @@ -84,12 +89,13 @@ class HTTPClientNIOTSTests: XCTestCase { guard isTestingNIOTS() else { return } #if canImport(Network) let httpBin = HTTPBin(.http1_1(ssl: false)) - var config = HTTPClient.Configuration() - config.networkFrameworkWaitForConnectivity = false - config.connectionPool.retryConnectionEstablishment = false + let config = HTTPClient.Configuration() + .enableFastFailureModeForTesting() - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: config) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: config + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -108,9 +114,15 @@ class HTTPClientNIOTSTests: XCTestCase { guard isTestingNIOTS() else { return } #if canImport(Network) let httpBin = HTTPBin(.http1_1(ssl: false)) - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(100), - read: .milliseconds(100)))) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init( + timeout: .init( + connect: .milliseconds(100), + read: .milliseconds(100) + ) + ) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown(requiresCleanClose: true)) @@ -140,9 +152,8 @@ class HTTPClientNIOTSTests: XCTestCase { tlsConfig.minimumTLSVersion = .tlsv11 tlsConfig.maximumTLSVersion = .tlsv1 - var clientConfig = HTTPClient.Configuration(tlsConfiguration: tlsConfig) - clientConfig.networkFrameworkWaitForConnectivity = false - clientConfig.connectionPool.retryConnectionEstablishment = false + let clientConfig = HTTPClient.Configuration(tlsConfiguration: tlsConfig) + .enableFastFailureModeForTesting() let httpClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), configuration: clientConfig diff --git a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift index b0b1be1d8..54467aab7 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientRequestTests.swift @@ -13,10 +13,13 @@ //===----------------------------------------------------------------------===// import Algorithms -@testable import AsyncHTTPClient +import NIOConcurrencyHelpers import NIOCore +import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) class HTTPClientRequestTests: XCTestCase { private typealias Request = HTTPClientRequest @@ -27,36 +30,56 @@ class HTTPClientRequestTests: XCTestCase { XCTAsyncTest { var request = Request(url: "https://example.com/get") request.headers = [ - "custom-header": "custom-header-value", + "custom-header": "custom-header-value" ] var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .https, - connectionTarget: .domain(name: "example.com", port: 443), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/get", - headers: [ - "host": "example.com", - "custom-header": "custom-header-value", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .https, + connectionTarget: .domain(name: "example.com", port: 443), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/get", + headers: [ + "host": "example.com", + "custom-header": "custom-header-value", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } } + func testBasicAuth() { + XCTAsyncTest { + var request = Request(url: "https://example.com/get") + request.setBasicAuth(username: "foo", password: "bar") + var preparedRequest: PreparedRequest? + XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) + guard let preparedRequest = preparedRequest else { return } + XCTAssertEqual(preparedRequest.head.headers.first(name: "Authorization")!, "Basic Zm9vOmJhcg==") + } + } + func testUnixScheme() { XCTAsyncTest { var request = Request(url: "unix://%2Fexample%2Ffolder.sock/some_path") @@ -65,22 +88,31 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .unix, - connectionTarget: .unixSocket(path: "/some_path"), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .unix, + connectionTarget: .unixSocket(path: "/some_path"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } @@ -94,22 +126,31 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .httpUnix, - connectionTarget: .unixSocket(path: "/example/folder.sock"), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/some_path", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .httpUnix, + connectionTarget: .unixSocket(path: "/example/folder.sock"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/some_path", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } @@ -123,22 +164,31 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .httpsUnix, - connectionTarget: .unixSocket(path: "/example/folder.sock"), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/some_path", - headers: ["custom-header": "custom-value"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .httpsUnix, + connectionTarget: .unixSocket(path: "/example/folder.sock"), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/some_path", + headers: ["custom-header": "custom-value"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } @@ -151,22 +201,31 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .https, - connectionTarget: .domain(name: "example.com", port: 443), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .GET, - uri: "/get", - headers: ["host": "example.com"] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .https, + connectionTarget: .domain(name: "example.com", port: 443), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .GET, + uri: "/get", + headers: ["host": "example.com"] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) } @@ -180,25 +239,34 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "0", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "0", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) @@ -214,25 +282,34 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "0", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(0) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "0", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(0) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, ByteBuffer()) @@ -248,25 +325,34 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } @@ -282,25 +368,34 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "transfer-encoding": "chunked", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .stream - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "transfer-encoding": "chunked", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .stream + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } @@ -312,30 +407,39 @@ class HTTPClientRequestTests: XCTestCase { request.method = .POST let sequence = AnySendableSequence(ByteBuffer(string: "post body").readableBytesView) - request.body = .bytes(sequence, length: .known(9)) + request.body = .bytes(sequence, length: .known(Int64(9))) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } @@ -351,25 +455,34 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } @@ -381,7 +494,7 @@ class HTTPClientRequestTests: XCTestCase { request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunks(ofCount: 2) + .uncheckedSendableChunks(ofCount: 2) .async .map { ByteBuffer($0) } @@ -390,25 +503,34 @@ class HTTPClientRequestTests: XCTestCase { XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "transfer-encoding": "chunked", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .stream - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "transfer-encoding": "chunked", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .stream + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } @@ -420,34 +542,43 @@ class HTTPClientRequestTests: XCTestCase { request.method = .POST let asyncSequence = ByteBuffer(string: "post body") .readableBytesView - .chunks(ofCount: 2) + .uncheckedSendableChunks(ofCount: 2) .async .map { ByteBuffer($0) } - request.body = .stream(asyncSequence, length: .known(9)) + request.body = .stream(asyncSequence, length: .known(Int64(9))) var preparedRequest: PreparedRequest? XCTAssertNoThrow(preparedRequest = try PreparedRequest(request)) guard let preparedRequest = preparedRequest else { return } - XCTAssertEqual(preparedRequest.poolKey, .init( - scheme: .http, - connectionTarget: .domain(name: "example.com", port: 80), - tlsConfiguration: nil, - serverNameIndicatorOverride: nil - )) - XCTAssertEqual(preparedRequest.head, .init( - version: .http1_1, - method: .POST, - uri: "/post", - headers: [ - "host": "example.com", - "content-length": "9", - ] - )) - XCTAssertEqual(preparedRequest.requestFramingMetadata, .init( - connectionClose: false, - body: .fixedSize(9) - )) + XCTAssertEqual( + preparedRequest.poolKey, + .init( + scheme: .http, + connectionTarget: .domain(name: "example.com", port: 80), + tlsConfiguration: nil, + serverNameIndicatorOverride: nil + ) + ) + XCTAssertEqual( + preparedRequest.head, + .init( + version: .http1_1, + method: .POST, + uri: "/post", + headers: [ + "host": "example.com", + "content-length": "9", + ] + ) + ) + XCTAssertEqual( + preparedRequest.requestFramingMetadata, + .init( + connectionClose: false, + body: .fixedSize(9) + ) + ) guard let buffer = await XCTAssertNoThrowWithResult(try await preparedRequest.body.read()) else { return } XCTAssertEqual(buffer, .init(string: "post body")) } @@ -455,9 +586,9 @@ class HTTPClientRequestTests: XCTestCase { func testChunkingRandomAccessCollection() async throws { let body = try await HTTPClientRequest.Body.bytes( - Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) ).collect() let expectedChunks = [ @@ -471,12 +602,10 @@ class HTTPClientRequestTests: XCTestCase { func testChunkingCollection() async throws { let body = try await HTTPClientRequest.Body.bytes( - ( - String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + - String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + - String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize) - ).utf8, - length: .known(bagOfBytesToByteBufferConversionChunkSize * 3) + (String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize)).utf8, + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)) ).collect() let expectedChunks = [ @@ -491,11 +620,11 @@ class HTTPClientRequestTests: XCTestCase { func testChunkingSequenceThatDoesNotImplementWithContiguousStorageIfAvailable() async throws { let bagOfBytesToByteBufferConversionChunkSize = 8 let body = try await HTTPClientRequest.Body._bytes( - AnySequence( - Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + AnySendableSequence( + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) ), - length: .known(bagOfBytesToByteBufferConversionChunkSize * 3), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)), bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, byteBufferMaxSize: byteBufferMaxSize ).collect() @@ -510,20 +639,20 @@ class HTTPClientRequestTests: XCTestCase { func testChunkingSequenceFastPath() async throws { func makeBytes() -> some Sequence & Sendable { - Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) } let body = try await HTTPClientRequest.Body.bytes( makeBytes(), - length: .known(bagOfBytesToByteBufferConversionChunkSize * 3) + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)) ).collect() var firstChunk = ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize)) firstChunk.writeImmutableBuffer(ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize)) let expectedChunks = [ - firstChunk, + firstChunk ] XCTAssertEqual(body, expectedChunks) @@ -533,13 +662,13 @@ class HTTPClientRequestTests: XCTestCase { let bagOfBytesToByteBufferConversionChunkSize = 8 let byteBufferMaxSize = 16 func makeBytes() -> some Sequence & Sendable { - Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) } let body = try await HTTPClientRequest.Body._bytes( makeBytes(), - length: .known(bagOfBytesToByteBufferConversionChunkSize * 3), + length: .known(Int64(bagOfBytesToByteBufferConversionChunkSize * 3)), bagOfBytesToByteBufferConversionChunkSize: bagOfBytesToByteBufferConversionChunkSize, byteBufferMaxSize: byteBufferMaxSize ).collect() @@ -557,12 +686,13 @@ class HTTPClientRequestTests: XCTestCase { func testBodyStringChunking() throws { let body = try HTTPClient.Body.string( - String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + - String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + - String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize) + String(repeating: "0", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "1", count: bagOfBytesToByteBufferConversionChunkSize) + + String(repeating: "2", count: bagOfBytesToByteBufferConversionChunkSize) ).collect().wait() let expectedChunks = [ + ByteBuffer(), // We're currently emitting an empty chunk first. ByteBuffer(repeating: UInt8(ascii: "0"), count: bagOfBytesToByteBufferConversionChunkSize), ByteBuffer(repeating: UInt8(ascii: "1"), count: bagOfBytesToByteBufferConversionChunkSize), ByteBuffer(repeating: UInt8(ascii: "2"), count: bagOfBytesToByteBufferConversionChunkSize), @@ -573,12 +703,13 @@ class HTTPClientRequestTests: XCTestCase { func testBodyChunkingRandomAccessCollection() throws { let body = try HTTPClient.Body.bytes( - Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + - Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) + Array(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize) + + Array(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize) ).collect().wait() let expectedChunks = [ + ByteBuffer(), // We're currently emitting an empty chunk first. ByteBuffer(repeating: 0, count: bagOfBytesToByteBufferConversionChunkSize), ByteBuffer(repeating: 1, count: bagOfBytesToByteBufferConversionChunkSize), ByteBuffer(repeating: 2, count: bagOfBytesToByteBufferConversionChunkSize), @@ -599,23 +730,23 @@ extension HTTPClient.Body { func collect() -> EventLoopFuture<[ByteBuffer]> { let eelg = EmbeddedEventLoopGroup(loops: 1) let el = eelg.next() - var body = [ByteBuffer]() + let body = NIOLockedValueBox<[ByteBuffer]>([]) let writer = StreamWriter { switch $0 { case .byteBuffer(let byteBuffer): - body.append(byteBuffer) + body.withLockedValue { $0.append(byteBuffer) } case .fileRegion: fatalError("file region not supported") } return el.makeSucceededVoidFuture() } - return self.stream(writer).map { _ in body } + return self.stream(writer).map { _ in body.withLockedValue { $0 } } } } private struct LengthMismatch: Error { - var announcedLength: Int - var actualLength: Int + var announcedLength: Int64 + var actualLength: Int64 } @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) @@ -631,20 +762,58 @@ extension Optional where Wrapped == HTTPClientRequest.Prepared.Body { case .sequence(let announcedLength, _, let generate): let buffer = generate(ByteBufferAllocator()) if case .known(let announcedLength) = announcedLength, - announcedLength != buffer.readableBytes { - throw LengthMismatch(announcedLength: announcedLength, actualLength: buffer.readableBytes) + announcedLength != Int64(buffer.readableBytes) + { + throw LengthMismatch(announcedLength: announcedLength, actualLength: Int64(buffer.readableBytes)) } return buffer - case .asyncSequence(length: let announcedLength, let generate): + case .asyncSequence(length: let announcedLength, let makeAsyncIterator): var accumulatedBuffer = ByteBuffer() + let generate = makeAsyncIterator() while var buffer = try await generate(ByteBufferAllocator()) { accumulatedBuffer.writeBuffer(&buffer) } if case .known(let announcedLength) = announcedLength, - announcedLength != accumulatedBuffer.readableBytes { - throw LengthMismatch(announcedLength: announcedLength, actualLength: accumulatedBuffer.readableBytes) + announcedLength != Int64(accumulatedBuffer.readableBytes) + { + throw LengthMismatch( + announcedLength: announcedLength, + actualLength: Int64(accumulatedBuffer.readableBytes) + ) } return accumulatedBuffer } } } + +// swift-algorithms hasn't adopted Sendable yet. By inspection ChunksOfCountCollection should be +// Sendable assuming the underlying collection is. This wrapper allows us to avoid a blanket +// preconcurrency import of the Algorithms module. +struct UncheckedSendableChunksOfCountCollection: Collection, @unchecked Sendable +where Base: Sendable { + typealias Element = Base.SubSequence + typealias Index = ChunksOfCountCollection.Index + + private let underlying: ChunksOfCountCollection + + init(_ underlying: ChunksOfCountCollection) { + self.underlying = underlying + } + + var startIndex: Index { self.underlying.startIndex } + var endIndex: Index { self.underlying.endIndex } + + subscript(position: Index) -> Base.SubSequence { + self.underlying[position] + } + + func index(after i: Index) -> Index { + self.underlying.index(after: i) + } +} + +extension Collection where Self: Sendable { + func uncheckedSendableChunks(ofCount count: Int) -> UncheckedSendableChunksOfCountCollection { + UncheckedSendableChunksOfCountCollection(self.chunks(ofCount: count)) + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift index 2c6c9afac..7dcc4efe6 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientResponseTests.swift @@ -12,25 +12,39 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore +import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class HTTPClientResponseTests: XCTestCase { func testSimpleResponse() { - let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .ok) + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "1025"], + status: .ok + ) XCTAssertEqual(response, 1025) } func testSimpleResponseNotModified() { - let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "1025"], status: .notModified) + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "1025"], + status: .notModified + ) XCTAssertEqual(response, 0) } func testSimpleResponseHeadRequestMethod() { - let response = HTTPClientResponse.expectedContentLength(requestMethod: .HEAD, headers: ["content-length": "1025"], status: .ok) + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .HEAD, + headers: ["content-length": "1025"], + status: .ok + ) XCTAssertEqual(response, 0) } @@ -40,7 +54,11 @@ final class HTTPClientResponseTests: XCTestCase { } func testResponseInvalidInteger() { - let response = HTTPClientResponse.expectedContentLength(requestMethod: .GET, headers: ["content-length": "none"], status: .ok) + let response = HTTPClientResponse.expectedContentLength( + requestMethod: .GET, + headers: ["content-length": "none"], + status: .ok + ) XCTAssertEqual(response, nil) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift index 2d37b1387..f9917c885 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift @@ -12,13 +12,13 @@ // //===----------------------------------------------------------------------===// -import AsyncHTTPClient import Atomics import Foundation import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHPACK import NIOHTTP1 import NIOHTTP2 @@ -28,10 +28,19 @@ import NIOSSL import NIOTLS import NIOTransportServices import XCTest -#if canImport(Darwin) + +@testable import AsyncHTTPClient + +#if canImport(xlocale) +import xlocale +#elseif canImport(locale_h) +import locale_h +#elseif canImport(Darwin) import Darwin #elseif canImport(Musl) import Musl +#elseif canImport(Android) +import Android #elseif canImport(Glibc) import Glibc #endif @@ -48,7 +57,8 @@ func isTestingNIOTS() -> Bool { func getDefaultEventLoopGroup(numberOfThreads: Int) -> EventLoopGroup { #if canImport(Network) if #available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *), - isTestingNIOTS() { + isTestingNIOTS() + { return NIOTSEventLoopGroup(loopCount: numberOfThreads, defaultQoS: .default) } #endif @@ -85,15 +95,13 @@ func withCLocaleSetToGerman(_ body: () throws -> Void) throws { try body() } -class TestHTTPDelegate: HTTPClientResponseDelegate { +final class TestHTTPDelegate: HTTPClientResponseDelegate { typealias Response = Void init(backpressureEventLoop: EventLoop? = nil) { - self.backpressureEventLoop = backpressureEventLoop + self.state = NIOLockedValueBox(MutableState(backpressureEventLoop: backpressureEventLoop)) } - var backpressureEventLoop: EventLoop? - enum State { case idle case head(HTTPResponseHead) @@ -102,77 +110,96 @@ class TestHTTPDelegate: HTTPClientResponseDelegate { case error(Error) } - var state = State.idle + struct MutableState: Sendable { + var state: State = .idle + var backpressureEventLoop: EventLoop? + } + + let state: NIOLockedValueBox func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.state = .head(head) - return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) + let eventLoop = self.state.withLockedValue { + $0.state = .head(head) + return ($0.backpressureEventLoop ?? task.eventLoop) + } + + return eventLoop.makeSucceededVoidFuture() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - switch self.state { - case .head(let head): - self.state = .body(head, buffer) - case .body(let head, var body): - var buffer = buffer - body.writeBuffer(&buffer) - self.state = .body(head, body) - default: - preconditionFailure("expecting head or body") + let eventLoop = self.state.withLockedValue { + switch $0.state { + case .head(let head): + $0.state = .body(head, buffer) + case .body(let head, var body): + var buffer = buffer + body.writeBuffer(&buffer) + $0.state = .body(head, body) + default: + preconditionFailure("expecting head or body") + } + return ($0.backpressureEventLoop ?? task.eventLoop) } - return (self.backpressureEventLoop ?? task.eventLoop).makeSucceededFuture(()) + + return eventLoop.makeSucceededVoidFuture() } func didFinishRequest(task: HTTPClient.Task) throws {} } -class CountingDelegate: HTTPClientResponseDelegate { +final class CountingDelegate: HTTPClientResponseDelegate { typealias Response = Int - var count = 0 + private let _count = NIOLockedValueBox(0) func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { let str = buffer.getString(at: 0, length: buffer.readableBytes) if str?.starts(with: "id:") ?? false { - self.count += 1 + self._count.withLockedValue { $0 += 1 } } return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Int { - return self.count + self._count.withLockedValue { $0 } } } -class DelayOnHeadDelegate: HTTPClientResponseDelegate { +final class DelayOnHeadDelegate: HTTPClientResponseDelegate { typealias Response = ByteBuffer let eventLoop: EventLoop - let didReceiveHead: (HTTPResponseHead, EventLoopPromise) -> Void + let didReceiveHead: @Sendable (HTTPResponseHead, EventLoopPromise) -> Void - private var data: ByteBuffer - - private var mayReceiveData = false + struct State: Sendable { + var data: ByteBuffer + var mayReceiveData = false + var expectError = false + } - private var expectError = false + private let state: NIOLockedValueBox - init(eventLoop: EventLoop, didReceiveHead: @escaping (HTTPResponseHead, EventLoopPromise) -> Void) { + init(eventLoop: EventLoop, didReceiveHead: @escaping @Sendable (HTTPResponseHead, EventLoopPromise) -> Void) { self.eventLoop = eventLoop self.didReceiveHead = didReceiveHead - self.data = ByteBuffer() + self.state = NIOLockedValueBox(State(data: ByteBuffer())) } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - XCTAssertFalse(self.mayReceiveData) - XCTAssertFalse(self.expectError) + self.state.withLockedValue { + XCTAssertFalse($0.mayReceiveData) + XCTAssertFalse($0.expectError) + } let promise = self.eventLoop.makePromise(of: Void.self) - promise.futureResult.whenComplete { - switch $0 { - case .success: - self.mayReceiveData = true - case .failure: - self.expectError = true + promise.futureResult.whenComplete { result in + self.state.withLockedValue { state in + switch result { + case .success: + state.mayReceiveData = true + case .failure: + state.expectError = true + } } } @@ -181,20 +208,26 @@ class DelayOnHeadDelegate: HTTPClientResponseDelegate { } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - XCTAssertTrue(self.mayReceiveData) - XCTAssertFalse(self.expectError) - self.data.writeImmutableBuffer(buffer) + self.state.withLockedValue { + XCTAssertTrue($0.mayReceiveData) + XCTAssertFalse($0.expectError) + $0.data.writeImmutableBuffer(buffer) + } return self.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Response { - XCTAssertTrue(self.mayReceiveData) - XCTAssertFalse(self.expectError) - return self.data + self.state.withLockedValue { + XCTAssertTrue($0.mayReceiveData) + XCTAssertFalse($0.expectError) + return $0.data + } } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - XCTAssertTrue(self.expectError) + self.state.withLockedValue { + XCTAssertTrue($0.expectError) + } } } @@ -215,8 +248,8 @@ enum TemporaryFileHelpers { } else { return "/tmp" } - #endif // os - #endif // targetEnvironment + #endif // os + #endif // targetEnvironment } private static func openTemporaryFile() -> (CInt, String) { @@ -236,8 +269,10 @@ enum TemporaryFileHelpers { /// /// If the temporary directory is too long to store a UNIX domain socket path, it will `chdir` into the temporary /// directory and return a short-enough path. The iOS simulator is known to have too long paths. - internal static func withTemporaryUnixDomainSocketPathName(directory: String = temporaryDirectory, - _ body: (String) throws -> T) throws -> T { + internal static func withTemporaryUnixDomainSocketPathName( + directory: String = temporaryDirectory, + _ body: (String) throws -> T + ) throws -> T { // this is racy but we're trying to create the shortest possible path so we can't add a directory... let (fd, path) = self.openTemporaryFile() close(fd) @@ -252,17 +287,21 @@ enum TemporaryFileHelpers { shortEnoughPath = path restoreSavedCWD = false } catch SocketAddressError.unixDomainSocketPathTooLong { - FileManager.default.changeCurrentDirectoryPath(URL(fileURLWithPath: path).deletingLastPathComponent().absoluteString) + _ = FileManager.default.changeCurrentDirectoryPath( + URL(fileURLWithPath: path).deletingLastPathComponent().absoluteString + ) shortEnoughPath = URL(fileURLWithPath: path).lastPathComponent restoreSavedCWD = true - print("WARNING: Path '\(path)' could not be used as UNIX domain socket path, using chdir & '\(shortEnoughPath)'") + print( + "WARNING: Path '\(path)' could not be used as UNIX domain socket path, using chdir & '\(shortEnoughPath)'" + ) } defer { if FileManager.default.fileExists(atPath: path) { try? FileManager.default.removeItem(atPath: path) } if restoreSavedCWD { - FileManager.default.changeCurrentDirectoryPath(saveCurrentDirectory) + _ = FileManager.default.changeCurrentDirectoryPath(saveCurrentDirectory) } } return try body(shortEnoughPath) @@ -303,11 +342,11 @@ enum TemporaryFileHelpers { } internal static func fileSize(path: String) throws -> Int? { - return try FileManager.default.attributesOfItem(atPath: path)[.size] as? Int + try FileManager.default.attributesOfItem(atPath: path)[.size] as? Int } internal static func fileExists(path: String) -> Bool { - return FileManager.default.fileExists(atPath: path) + FileManager.default.fileExists(atPath: path) } } @@ -320,9 +359,11 @@ enum TestTLS { ) } -internal final class HTTPBin where +internal final class HTTPBin: Sendable +where RequestHandler.InboundIn == HTTPServerRequestPart, - RequestHandler.OutboundOut == HTTPServerResponsePart { + RequestHandler.OutboundOut == HTTPServerResponsePart +{ enum BindTarget { case unixDomainSocket(String) case localhostIPv4RandomPort @@ -361,10 +402,7 @@ internal final class HTTPBin where var httpSettings: HTTP2Settings { switch self { case .http1_1, .http2(_, _, nil), .refuse: - return [ - HTTP2Setting(parameter: .maxConcurrentStreams, value: 10), - HTTP2Setting(parameter: .maxHeaderListSize, value: HPACKDecoder.defaultMaxHeaderListSize), - ] + return HTTP2Connection.defaultSettings case .http2(_, _, .some(let customSettings)): return customSettings } @@ -392,19 +430,23 @@ internal final class HTTPBin where private let activeConnCounterHandler: ConnectionsCountHandler var activeConnections: Int { - return self.activeConnCounterHandler.currentlyActiveConnections + self.activeConnCounterHandler.currentlyActiveConnections } var createdConnections: Int { - return self.activeConnCounterHandler.createdConnections + self.activeConnCounterHandler.createdConnections } var port: Int { - return Int(self.serverChannel.localAddress!.port!) + self.serverChannel.withLockedValue { + Int($0!.localAddress!.port!) + } } var socketAddress: SocketAddress { - return self.serverChannel.localAddress! + self.serverChannel.withLockedValue { + $0!.localAddress! + } } var baseURL: String { @@ -432,16 +474,17 @@ internal final class HTTPBin where private let mode: Mode private let sslContext: NIOSSLContext? - private var serverChannel: Channel! + private let serverChannel = NIOLockedValueBox(nil) private let isShutdown = ManagedAtomic(false) - private let handlerFactory: (Int) -> (RequestHandler) + private let handlerFactory: @Sendable (Int) -> (RequestHandler) init( _ mode: Mode = .http1_1(ssl: false, compress: false), proxy: Proxy = .none, bindTarget: BindTarget = .localhostIPv4RandomPort, reusePort: Bool = false, - handlerFactory: @escaping (Int) -> (RequestHandler) + trafficShapingTargetBytesPerSecond: Int? = nil, + handlerFactory: @escaping @Sendable (Int) -> (RequestHandler) ) { self.mode = mode self.sslContext = HTTPBin.sslContext(for: mode) @@ -461,12 +504,22 @@ internal final class HTTPBin where let connectionIDAtomic = ManagedAtomic(0) - self.serverChannel = try! ServerBootstrap(group: self.group) + let serverChannel = try! ServerBootstrap(group: self.group) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) - .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: reusePort ? 1 : 0) - .serverChannelInitializer { channel in - channel.pipeline.addHandler(self.activeConnCounterHandler) + .serverChannelOption( + ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), + value: reusePort ? 1 : 0 + ) + .serverChannelInitializer { [activeConnCounterHandler] channel in + channel.pipeline.addHandler(activeConnCounterHandler) }.childChannelInitializer { channel in + if let trafficShapingTargetBytesPerSecond = trafficShapingTargetBytesPerSecond { + try! channel.pipeline.syncOperations.addHandler( + BasicInboundTrafficShapingHandler( + targetBytesPerSecond: trafficShapingTargetBytesPerSecond + ) + ) + } do { let connectionID = connectionIDAtomic.loadThenWrappingIncrement(ordering: .relaxed) @@ -502,6 +555,7 @@ internal final class HTTPBin where return channel.eventLoop.makeFailedFuture(error) } }.bind(to: socketAddress).wait() + self.serverChannel.withLockedValue { $0 = serverChannel } } private func syncAddHTTPProxyHandlers( @@ -520,12 +574,12 @@ internal final class HTTPBin where try sync.addHandler(requestDecoder) try sync.addHandler(proxySimulator) - promise.futureResult.flatMap { _ in - channel.pipeline.removeHandler(proxySimulator) + promise.futureResult.assumeIsolated().flatMap { _ in + channel.pipeline.syncOperations.removeHandler(proxySimulator) }.flatMap { _ in - channel.pipeline.removeHandler(responseEncoder) + channel.pipeline.syncOperations.removeHandler(responseEncoder) }.flatMap { _ in - channel.pipeline.removeHandler(requestDecoder) + channel.pipeline.syncOperations.removeHandler(requestDecoder) }.whenComplete { result in switch result { case .failure: @@ -591,6 +645,7 @@ internal final class HTTPBin where let multiplexer = HTTP2StreamMultiplexer( mode: .server, channel: channel, + targetWindowSize: 16 * 1024 * 1024, // 16 MiB inboundStreamInitializer: { channel in do { let sync = channel.pipeline.syncOperations @@ -628,8 +683,8 @@ internal final class HTTPBin where } } + try channel.pipeline.syncOperations.addHandler(sslHandler) try channel.pipeline.syncOperations.addHandler(alpnHandler) - try channel.pipeline.syncOperations.addHandler(sslHandler, position: .before(alpnHandler)) } func shutdown() throws { @@ -647,9 +702,16 @@ extension HTTPBin where RequestHandler == HTTPBinHandler { _ mode: Mode = .http1_1(ssl: false, compress: false), proxy: Proxy = .none, bindTarget: BindTarget = .localhostIPv4RandomPort, - reusePort: Bool = false + reusePort: Bool = false, + trafficShapingTargetBytesPerSecond: Int? = nil ) { - self.init(mode, proxy: proxy, bindTarget: bindTarget, reusePort: reusePort) { HTTPBinHandler(connectionID: $0) } + self.init( + mode, + proxy: proxy, + bindTarget: bindTarget, + reusePort: reusePort, + trafficShapingTargetBytesPerSecond: trafficShapingTargetBytesPerSecond + ) { HTTPBinHandler(connectionID: $0) } } } @@ -672,7 +734,11 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { init(promise: EventLoopPromise, expectedAuthorization: String?) { self.promise = promise self.expectedAuthorization = expectedAuthorization - self.head = HTTPResponseHead(version: .init(major: 1, minor: 1), status: .ok, headers: .init([("Content-Length", "0")])) + self.head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .ok, + headers: .init([("Content-Length", "0")]) + ) } func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -686,7 +752,8 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { if let expectedAuthorization = self.expectedAuthorization { guard let authorization = head.headers["proxy-authorization"].first, - expectedAuthorization == authorization else { + expectedAuthorization == authorization + else { self.head.status = .proxyAuthenticationRequired return } @@ -710,12 +777,31 @@ final class HTTPProxySimulator: ChannelInboundHandler, RemovableChannelHandler { internal struct HTTPResponseBuilder { var head: HTTPResponseHead var body: ByteBuffer? + var requestBodyByteCount: Int + let responseBodyIsRequestBodyByteCount: Bool - init(_ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), status: HTTPResponseStatus, headers: HTTPHeaders = HTTPHeaders()) { + init( + _ version: HTTPVersion = HTTPVersion(major: 1, minor: 1), + status: HTTPResponseStatus, + headers: HTTPHeaders = HTTPHeaders(), + responseBodyIsRequestBodyByteCount: Bool = false + ) { self.head = HTTPResponseHead(version: version, status: status, headers: headers) + self.requestBodyByteCount = 0 + self.responseBodyIsRequestBodyByteCount = responseBodyIsRequestBodyByteCount } mutating func add(_ part: ByteBuffer) { + self.requestBodyByteCount += part.readableBytes + guard !self.responseBodyIsRequestBodyByteCount else { + if self.body == nil { + self.body = ByteBuffer() + self.body!.reserveCapacity(100) + } + self.body!.clear() + self.body!.writeString("\(self.requestBodyByteCount)") + return + } if var body = body { var part = part body.writeBuffer(&part) @@ -763,8 +849,10 @@ internal final class HTTPBinHandler: ChannelInboundHandler { for header in head.headers { let needle = "x-send-back-header-" if header.name.lowercased().starts(with: needle) { - self.responseHeaders.add(name: String(header.name.dropFirst(needle.count)), - value: header.value) + self.responseHeaders.add( + name: String(header.name.dropFirst(needle.count)), + value: header.value + ) } } } @@ -777,7 +865,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { headers = HTTPHeaders() } - context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.write( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) for i in 0..<10 { let msg = "id: \(i)" var buf = context.channel.allocator.buffer(capacity: msg.count) @@ -792,7 +885,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { // This tests receiving chunks very fast: please do not insert delays here! let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) - context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.write( + self.wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) for i in 0..<10 { let msg = "id: \(i)" var buf = context.channel.allocator.buffer(capacity: msg.count) @@ -807,7 +905,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { // This tests receiving a lot of tiny chunks: they must all be sent in a single flush or the test doesn't work. let headers = HTTPHeaders([("Transfer-Encoding", "chunked")]) - context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.write( + self.wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) let message = ByteBuffer(integer: UInt8(ascii: "a")) // This number (10k) is load-bearing and a bit magic: it has been experimentally verified as being sufficient to blow the stack @@ -867,6 +970,13 @@ internal final class HTTPBinHandler: ChannelInboundHandler { } self.resps.append(HTTPResponseBuilder(status: .ok)) return + case "/post-respond-with-byte-count": + if req.method != .POST { + self.resps.append(HTTPResponseBuilder(status: .methodNotAllowed)) + return + } + self.resps.append(HTTPResponseBuilder(status: .ok, responseBodyIsRequestBodyByteCount: true)) + return case "/redirect/302": var headers = self.responseHeaders headers.add(name: "location", value: "/ok") @@ -927,9 +1037,12 @@ internal final class HTTPBinHandler: ChannelInboundHandler { context.close(promise: nil) return case "/custom": - context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), promise: nil) + context.writeAndFlush( + wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok))), + promise: nil + ) return - case "/events/10/1": // TODO: parse path + case "/events/10/1": // TODO: parse path self.writeEvents(context: context) return case "/events/10/content-length": @@ -953,10 +1066,20 @@ internal final class HTTPBinHandler: ChannelInboundHandler { case "/content-length-without-body": var headers = self.responseHeaders headers.replaceOrAdd(name: "content-length", value: "1234") - context.writeAndFlush(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil) + context.writeAndFlush( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers)) + ), + promise: nil + ) return default: - context.write(wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound))), promise: nil) + context.write( + wrapOutboundOut( + .head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .notFound)) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) return } @@ -975,32 +1098,41 @@ internal final class HTTPBinHandler: ChannelInboundHandler { response.head.headers.add(contentsOf: self.responseHeaders) context.write(wrapOutboundOut(.head(response.head)), promise: nil) if let body = response.body { - let requestInfo = RequestInfo(data: String(buffer: body), - requestNumber: self.requestId, - connectionNumber: self.connectionID) - let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, - allocator: context.channel.allocator) + let requestInfo = RequestInfo( + data: String(buffer: body), + requestNumber: self.requestId, + connectionNumber: self.connectionID + ) + let responseBody = try! JSONEncoder().encodeAsByteBuffer( + requestInfo, + allocator: context.channel.allocator + ) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } else { - let requestInfo = RequestInfo(data: "", - requestNumber: self.requestId, - connectionNumber: self.connectionID) - let responseBody = try! JSONEncoder().encodeAsByteBuffer(requestInfo, - allocator: context.channel.allocator) + let requestInfo = RequestInfo( + data: "", + requestNumber: self.requestId, + connectionNumber: self.connectionID + ) + let responseBody = try! JSONEncoder().encodeAsByteBuffer( + requestInfo, + allocator: context.channel.allocator + ) context.write(wrapOutboundOut(.body(.byteBuffer(responseBody))), promise: nil) } - context.eventLoop.scheduleTask(in: self.delay) { + context.eventLoop.assumeIsolated().scheduleTask(in: self.delay) { guard context.channel.isActive else { context.close(promise: nil) return } - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenComplete { result in + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenComplete { result in self.isServingRequest = false switch result { case .success: - if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") || - self.shouldClose { + if self.responseHeaders[canonicalForm: "X-Close-Connection"].contains("true") + || self.shouldClose + { context.close(promise: nil) } case .failure(let error): @@ -1029,7 +1161,7 @@ internal final class HTTPBinHandler: ChannelInboundHandler { } } -final class ConnectionsCountHandler: ChannelInboundHandler { +final class ConnectionsCountHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Channel private let activeConns = ManagedAtomic(0) @@ -1048,8 +1180,8 @@ final class ConnectionsCountHandler: ChannelInboundHandler { _ = self.activeConns.loadThenWrappingIncrement(ordering: .relaxed) _ = self.createdConns.loadThenWrappingIncrement(ordering: .relaxed) - channel.closeFuture.whenComplete { _ in - _ = self.activeConns.loadThenWrappingDecrement(ordering: .relaxed) + channel.closeFuture.whenComplete { [activeConns] _ in + _ = activeConns.loadThenWrappingDecrement(ordering: .relaxed) } context.fireChannelRead(data) @@ -1069,7 +1201,7 @@ internal final class CloseWithoutClosingServerHandler: ChannelInboundHandler { func handlerAdded(context: ChannelHandlerContext) { self.onClosePromise = context.eventLoop.makePromise() - self.onClosePromise!.futureResult.whenSuccess(self.callback!) + self.onClosePromise!.futureResult.assumeIsolated().whenSuccess(self.callback!) self.callback = nil } @@ -1131,7 +1263,7 @@ final class ExpectClosureServerHandler: ChannelInboundHandler { struct EventLoopFutureTimeoutError: Error {} -extension EventLoopFuture { +extension EventLoopFuture where Value: Sendable { func timeout(after failDelay: TimeAmount) -> EventLoopFuture { let promise = self.eventLoop.makePromise(of: Value.self) @@ -1157,30 +1289,33 @@ struct CollectEverythingLogHandler: LogHandler { var logLevel: Logger.Level = .info let logStore: LogStore - class LogStore { + final class LogStore: Sendable { struct Entry { var level: Logger.Level var message: String var metadata: [String: String] } - var lock = NIOLock() - var logs: [Entry] = [] + private let logs = NIOLockedValueBox<[Entry]>([]) var allEntries: [Entry] { get { - return self.lock.withLock { self.logs } + self.logs.withLockedValue { $0 } } set { - self.lock.withLock { self.logs = newValue } + self.logs.withLockedValue { $0 = newValue } } } func append(level: Logger.Level, message: Logger.Message, metadata: Logger.Metadata?) { - self.lock.withLock { - self.logs.append(Entry(level: level, - message: message.description, - metadata: metadata?.mapValues { $0.description } ?? [:])) + self.logs.withLockedValue { + $0.append( + Entry( + level: level, + message: message.description, + metadata: metadata?.mapValues { $0.description } ?? [:] + ) + ) } } } @@ -1189,16 +1324,21 @@ struct CollectEverythingLogHandler: LogHandler { self.logStore = logStore } - func log(level: Logger.Level, - message: Logger.Message, - metadata: Logger.Metadata?, - file: String, function: String, line: UInt) { + func log( + level: Logger.Level, + message: Logger.Message, + metadata: Logger.Metadata?, + source: String, + file: String, + function: String, + line: UInt + ) { self.logStore.append(level: level, message: message, metadata: self.metadata.merging(metadata ?? [:]) { $1 }) } subscript(metadataKey key: String) -> Logger.Metadata.Value? { get { - return self.metadata[key] + self.metadata[key] } set { self.metadata[key] = newValue @@ -1210,10 +1350,10 @@ struct CollectEverythingLogHandler: LogHandler { /// consume the bytes by calling ``next()`` on the delegate. /// /// The sole purpose of this class is to enable straight-line stream tests. -class ResponseStreamDelegate: HTTPClientResponseDelegate { +final class ResponseStreamDelegate: HTTPClientResponseDelegate { typealias Response = Void - enum State { + enum State: Sendable { /// The delegate is in the idle state. There are no http response parts to be buffered /// and the consumer did not signal a demand. Transitions to all other states are allowed. case idle @@ -1231,10 +1371,11 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { } let eventLoop: EventLoop - private var state: State = .idle + private let state: NIOLoopBoundBox init(eventLoop: EventLoop) { self.eventLoop = eventLoop + self.state = .makeBoxSendingValue(.idle, eventLoop: eventLoop) } func next() -> EventLoopFuture { @@ -1248,25 +1389,25 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { } private func next0() -> EventLoopFuture { - switch self.state { + switch self.state.value { case .idle: let promise = self.eventLoop.makePromise(of: ByteBuffer?.self) - self.state = .waitingForBytes(promise) + self.state.value = .waitingForBytes(promise) return promise.futureResult case .buffering(let byteBuffer, done: false): - self.state = .idle + self.state.value = .idle return self.eventLoop.makeSucceededFuture(byteBuffer) case .buffering(let byteBuffer, done: true): - self.state = .finished + self.state.value = .finished return self.eventLoop.makeSucceededFuture(byteBuffer) case .waitingForBytes: preconditionFailure("Don't call `.next` twice") case .failed(let error): - self.state = .finished + self.state.value = .finished return self.eventLoop.makeFailedFuture(error) case .finished: @@ -1296,16 +1437,16 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .buffering(buffer, done: false) + self.state.value = .buffering(buffer, done: false) case .waitingForBytes(let promise): - self.state = .idle + self.state.value = .idle promise.succeed(buffer) case .buffering(var byteBuffer, done: false): var buffer = buffer byteBuffer.writeBuffer(&buffer) - self.state = .buffering(byteBuffer, done: false) + self.state.value = .buffering(byteBuffer, done: false) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1316,14 +1457,14 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didReceiveError(task: HTTPClient.Task, _ error: Error) { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .failed(error) + self.state.value = .failed(error) case .waitingForBytes(let promise): - self.state = .finished + self.state.value = .finished promise.fail(error) case .buffering(_, done: false): - self.state = .failed(error) + self.state.value = .failed(error) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1332,14 +1473,14 @@ class ResponseStreamDelegate: HTTPClientResponseDelegate { func didFinishRequest(task: HTTPClient.Task) throws { self.eventLoop.preconditionInEventLoop() - switch self.state { + switch self.state.value { case .idle: - self.state = .finished + self.state.value = .finished case .waitingForBytes(let promise): - self.state = .finished + self.state.value = .finished promise.succeed(nil) case .buffering(let byteBuffer, done: false): - self.state = .buffering(byteBuffer, done: true) + self.state.value = .buffering(byteBuffer, done: true) case .buffering(_, done: true), .finished, .failed: preconditionFailure("Invalid state: \(self.state)") } @@ -1354,11 +1495,14 @@ class HTTPEchoHandler: ChannelInboundHandler { let request = self.unwrapInboundIn(data) switch request { case .head(let requestHead): - context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), + promise: nil + ) case .body(let bytes): context.writeAndFlush(self.wrapOutboundOut(.body(.byteBuffer(bytes))), promise: nil) case .end: - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenSuccess { + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenSuccess { context.close(promise: nil) } } @@ -1373,11 +1517,14 @@ final class HTTPEchoHeaders: ChannelInboundHandler { let request = self.unwrapInboundIn(data) switch request { case .head(let requestHead): - context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok, headers: requestHead.headers))), + promise: nil + ) case .body: break case .end: - context.writeAndFlush(self.wrapOutboundOut(.end(nil))).whenSuccess { + context.writeAndFlush(self.wrapOutboundOut(.end(nil))).assumeIsolated().whenSuccess { context.close(promise: nil) } } @@ -1409,7 +1556,10 @@ final class HTTP200DelayedHandler: ChannelInboundHandler { self.pendingBodyParts = pendingBodyParts - 1 } else { self.pendingBodyParts = nil - context.writeAndFlush(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) + context.writeAndFlush( + self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } } @@ -1420,51 +1570,136 @@ final class HTTP200DelayedHandler: ChannelInboundHandler { } private let cert = """ ------BEGIN CERTIFICATE----- -MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 -czAgFw0xODEwMzExNTU1MjJaGA8yMTE4MTAwNzE1NTUyMlowDTELMAkGA1UEBhMC -dXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDiC+TGmbSP/nWWN1tj -yNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMisUdb -d3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZHud9 -+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKzV3S8 -kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAVKcNR -9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO538ljg -dslnAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAFYhA7sw8odOsRO8/DUklBOjPnmn -a078oSumgPXXw6AgcoAJv/Qthjo6CCEtrjYfcA9jaBw9/Tii7mDmqDRS5c9ZPL8+ -NEPdHjFCFBOEvlL6uHOgw0Z9Wz+5yCXnJ8oNUEgc3H2NbbzJF6sMBXSPtFS2NOK8 -OsAI9OodMrDd6+lwljrmFoCCkJHDEfE637IcsbgFKkzhO/oNCRK6OrudG4teDahz -Au4LoEYwT730QKC/VQxxEVZobjn9/sTrq9CZlbPYHxX4fz6e00sX7H9i49vk9zQ5 -5qCm9ljhrQPSa42Q62PPE2BEEGSP2KBm0J+H3vlvCD6+SNc/nMZjrRmgjrI= ------END CERTIFICATE----- -""" + -----BEGIN CERTIFICATE----- + MIICmDCCAYACCQCPC8JDqMh1zzANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1 + czAgFw0xODEwMzExNTU1MjJaGA8yMTE4MTAwNzE1NTUyMlowDTELMAkGA1UEBhMC + dXMwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDiC+TGmbSP/nWWN1tj + yNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMisUdb + d3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZHud9 + +JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKzV3S8 + kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAVKcNR + 9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO538ljg + dslnAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAFYhA7sw8odOsRO8/DUklBOjPnmn + a078oSumgPXXw6AgcoAJv/Qthjo6CCEtrjYfcA9jaBw9/Tii7mDmqDRS5c9ZPL8+ + NEPdHjFCFBOEvlL6uHOgw0Z9Wz+5yCXnJ8oNUEgc3H2NbbzJF6sMBXSPtFS2NOK8 + OsAI9OodMrDd6+lwljrmFoCCkJHDEfE637IcsbgFKkzhO/oNCRK6OrudG4teDahz + Au4LoEYwT730QKC/VQxxEVZobjn9/sTrq9CZlbPYHxX4fz6e00sX7H9i49vk9zQ5 + 5qCm9ljhrQPSa42Q62PPE2BEEGSP2KBm0J+H3vlvCD6+SNc/nMZjrRmgjrI= + -----END CERTIFICATE----- + """ private let key = """ ------BEGIN PRIVATE KEY----- -MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDiC+TGmbSP/nWW -N1tjyNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMi -sUdbd3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZ -Hud9+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKz -V3S8kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAV -KcNR9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO53 -8ljgdslnAgMBAAECggEBANZNWFNAnYJ2R5xmVuo/GxFk68Ujd4i4TZpPYbhkk+QG -g8I0w5htlEQQkVHfZx2CpTvq8feuAH/YhlA5qeD5WaPwq26q5qsmyV6tQGDgb9lO -w85l6ySZDbwdVOJe2il/MSB6MclSKvTGNm59chJnfHYsmvY3HHq4qsc2F+tRKYMW -pY75LgEbaTUV69J3cbC1wAeVjv0q/krND+YkhYpTxNZhbazK/FHOCvY+zFu9fg0L -zpwbn5fb6wIvqG7tXp7koa3QMn64AXmO/fb5mBd8G2vBGYnxwb7Egwdg/3Dw+BXu -ynQLP7ixWsE2KNfR9Ce1i3YvEo6QDTv2340I3dntxkECgYEA9vdaL4PGyvEbpim4 -kqz1vuug8Iq0nTVDo6jmgH1o+XdcIbW3imXtgi5zUJpj4oDD7/4aufiJZjG64i/v -phe11xeUvh5QNNOzeMymVDoJut97F97KKKTv7bG8Rpon/WzH2I0SoAkECCwmdWAJ -H3nvOCnXEkpbCqmIUvHVURPRDn8CgYEA6lCk3EzFQlbXs3Sj5op61R3Mscx7/35A -eGv5axzbENHt1so+s3Zvyyi1bo4VBcwnKVCvQjmTuLiqrc9VfX8XdbiTUNnEr2u3 -992Ja6DEJTZ9gy5WiviwYnwU2HpjwOVNBb17T0NLoRHkDZ6iXj7NZgwizOki5p3j -/hS0pObSIRkCgYEAiEdOGNIarHoHy9VR6H5QzR2xHYssx2NRA8p8B4MsnhxjVqaz -tUcxnJiNQXkwjRiJBrGthdnD2ASxH4dcMsb6rMpyZcbMc5ouewZS8j9khx4zCqUB -4RPC4eMmBb+jOZEBZlnSYUUYWHokbrij0B61BsTvzUQCoQuUElEoaSkKP3kCgYEA -mwdqXHvK076jjo9w1drvtEu4IDc8H2oH++TsrEr2QiWzaDZ9z71f8BnqGNCW5jQS -AQrqOjXgIArGmqMgXB0Xh4LsrUS4Fpx9ptiD0JsYy8pGtuGUzvQFt9OC80ve7kSI -dnDMwj+zLUmqCrzXjuWcfpUu/UaPGeiDbZuDfcteYhkCgYBLyL5JY7Qd4gVQIhFX -7Sv3sNJN3KZCQHEzut7IwojaxgpuxiFvgsoXXuYolVCQp32oWbYcE2Yke+hOKsTE -sCMAWZiSGN2Nrfea730IYAXkUm8bpEd3VxDXEEv13nxVeQof+JGMdlkldFGaBRDU -oYQsPj00S3/GA9WDapwe81Wl2A== ------END PRIVATE KEY----- -""" + -----BEGIN PRIVATE KEY----- + MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDiC+TGmbSP/nWW + N1tjyNfnWCU5ATjtIOfdtP6ycx8JSeqkvyNXG21kNUn14jTTU8BglGL2hfVpCbMi + sUdbd3LpP8unSsvlOWwORFOViSy4YljSNM/FNoMtavuITA/sEELYgjWkz2o/uHPZ + Hud9+JQwGJgqIlMa3mr2IaaUZlWN3D1u88bzJYhpt3YyxRy9+OEoOKy36KdWwhKz + V3S8kXb0Y1GbAo68jJ9RfzeLy290mIs9qG2y1CNXWO6sxf6B//LaalizZiCfzYAV + KcNR9oNYsEJc5KB/+DsAGTzR7mL+oiU4h/vwVb2GTDat5C+PFGi6j1ujxYTRPO53 + 8ljgdslnAgMBAAECggEBANZNWFNAnYJ2R5xmVuo/GxFk68Ujd4i4TZpPYbhkk+QG + g8I0w5htlEQQkVHfZx2CpTvq8feuAH/YhlA5qeD5WaPwq26q5qsmyV6tQGDgb9lO + w85l6ySZDbwdVOJe2il/MSB6MclSKvTGNm59chJnfHYsmvY3HHq4qsc2F+tRKYMW + pY75LgEbaTUV69J3cbC1wAeVjv0q/krND+YkhYpTxNZhbazK/FHOCvY+zFu9fg0L + zpwbn5fb6wIvqG7tXp7koa3QMn64AXmO/fb5mBd8G2vBGYnxwb7Egwdg/3Dw+BXu + ynQLP7ixWsE2KNfR9Ce1i3YvEo6QDTv2340I3dntxkECgYEA9vdaL4PGyvEbpim4 + kqz1vuug8Iq0nTVDo6jmgH1o+XdcIbW3imXtgi5zUJpj4oDD7/4aufiJZjG64i/v + phe11xeUvh5QNNOzeMymVDoJut97F97KKKTv7bG8Rpon/WzH2I0SoAkECCwmdWAJ + H3nvOCnXEkpbCqmIUvHVURPRDn8CgYEA6lCk3EzFQlbXs3Sj5op61R3Mscx7/35A + eGv5axzbENHt1so+s3Zvyyi1bo4VBcwnKVCvQjmTuLiqrc9VfX8XdbiTUNnEr2u3 + 992Ja6DEJTZ9gy5WiviwYnwU2HpjwOVNBb17T0NLoRHkDZ6iXj7NZgwizOki5p3j + /hS0pObSIRkCgYEAiEdOGNIarHoHy9VR6H5QzR2xHYssx2NRA8p8B4MsnhxjVqaz + tUcxnJiNQXkwjRiJBrGthdnD2ASxH4dcMsb6rMpyZcbMc5ouewZS8j9khx4zCqUB + 4RPC4eMmBb+jOZEBZlnSYUUYWHokbrij0B61BsTvzUQCoQuUElEoaSkKP3kCgYEA + mwdqXHvK076jjo9w1drvtEu4IDc8H2oH++TsrEr2QiWzaDZ9z71f8BnqGNCW5jQS + AQrqOjXgIArGmqMgXB0Xh4LsrUS4Fpx9ptiD0JsYy8pGtuGUzvQFt9OC80ve7kSI + dnDMwj+zLUmqCrzXjuWcfpUu/UaPGeiDbZuDfcteYhkCgYBLyL5JY7Qd4gVQIhFX + 7Sv3sNJN3KZCQHEzut7IwojaxgpuxiFvgsoXXuYolVCQp32oWbYcE2Yke+hOKsTE + sCMAWZiSGN2Nrfea730IYAXkUm8bpEd3VxDXEEv13nxVeQof+JGMdlkldFGaBRDU + oYQsPj00S3/GA9WDapwe81Wl2A== + -----END PRIVATE KEY----- + """ + +final class BasicInboundTrafficShapingHandler: ChannelDuplexHandler { + typealias OutboundIn = ByteBuffer + typealias InboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + enum ReadState { + case flowingFreely + case pausing + case paused + + mutating func pause() { + switch self { + case .flowingFreely: + self = .pausing + case .pausing, .paused: + () // nothing to do + } + } + + mutating func unpause() -> Bool { + switch self { + case .flowingFreely: + return false // no extra `read` needed + case .pausing: + self = .flowingFreely + return false // no extra `read` needed + case .paused: + self = .flowingFreely + return true // yes, we need an extra read + } + } + + mutating func shouldRead() -> Bool { + switch self { + case .flowingFreely: + return true + case .pausing: + self = .paused + return false + case .paused: + return false + } + } + } + + private let targetBytesPerSecond: Int + private var currentSecondBytesSeen: Int = 0 + private var readState: ReadState = .flowingFreely + + init(targetBytesPerSecond: Int) { + self.targetBytesPerSecond = targetBytesPerSecond + } + + func evaluatePause(context: ChannelHandlerContext) { + if self.currentSecondBytesSeen >= self.targetBytesPerSecond { + self.readState.pause() + } else if self.currentSecondBytesSeen < self.targetBytesPerSecond { + if self.readState.unpause() { + context.read() + } + } + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let loopBoundContext = NIOLoopBound(context, eventLoop: context.eventLoop) + defer { + context.fireChannelRead(data) + } + let buffer = Self.unwrapInboundIn(data) + let byteCount = buffer.readableBytes + self.currentSecondBytesSeen += byteCount + context.eventLoop.assumeIsolated().scheduleTask(in: .seconds(1)) { + self.currentSecondBytesSeen -= byteCount + self.evaluatePause(context: loopBoundContext.value) + } + self.evaluatePause(context: context) + } + + func read(context: ChannelHandlerContext) { + if self.readState.shouldRead() { + context.read() + } + } +} diff --git a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift index 1bfca1d30..50c3ecb9d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientTests.swift @@ -12,17 +12,15 @@ // //===----------------------------------------------------------------------===// -/* NOT @testable */ import AsyncHTTPClient // Tests that need @testable go into HTTPClientInternalTests.swift +import AsyncHTTPClient // NOT @testable - tests that need @testable go into HTTPClientInternalTests.swift import Atomics -#if canImport(Network) -import Network -#endif import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOFoundationCompat import NIOHTTP1 +import NIOHTTP2 import NIOHTTPCompression import NIOPosix import NIOSSL @@ -30,6 +28,10 @@ import NIOTestUtils import NIOTransportServices import XCTest +#if canImport(Network) +import Network +#endif + final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testRequestURI() throws { let request1 = try Request(url: "https://someserver.com:8888/some/path?foo=bar") @@ -43,8 +45,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(request2.url.path, "") let request3 = try Request(url: "unix:///tmp/file") - XCTAssertNil(request3.url.host) XCTAssertEqual(request3.host, "") + #if os(Linux) && compiler(>=6.0) && compiler(<6.1) + XCTAssertEqual(request3.url.host, "") + #else + XCTAssertNil(request3.url.host) + #endif XCTAssertEqual(request3.url.path, "/tmp/file") XCTAssertEqual(request3.port, 80) XCTAssertFalse(request3.useTLS) @@ -118,7 +124,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(url.scheme, "http+unix") XCTAssertEqual(url.host, "/tmp/file with spacesと漢字") XCTAssertEqual(url.path, "/file/path") - XCTAssertEqual(url.absoluteString, "http+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path") + XCTAssertEqual( + url.absoluteString, + "http+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path" + ) } let url5 = URL(httpsURLWithSocketPath: "/tmp/file") @@ -154,7 +163,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(url.scheme, "https+unix") XCTAssertEqual(url.host, "/tmp/file with spacesと漢字") XCTAssertEqual(url.path, "/file/path") - XCTAssertEqual(url.absoluteString, "https+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path") + XCTAssertEqual( + url.absoluteString, + "https+unix://%2Ftmp%2Ffile%20with%20spaces%E3%81%A8%E6%BC%A2%E5%AD%97/file/path" + ) } } @@ -167,55 +179,116 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testConvenienceExecuteMethods() throws { - XCTAssertEqual(["GET"[...]], - try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["PATCH"[...]], - try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["PUT"[...]], - try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["DELETE"[...]], - try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["CHECKOUT"[...]], - try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["PATCH"[...]], + try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["PUT"[...]], + try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["DELETE"[...]], + try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["CHECKOUT"[...]], + try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) } func testConvenienceExecuteMethodsOverSocket() throws { - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try self.defaultClient.execute(.GET, socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try self.defaultClient.execute(.POST, socketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - }) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try self.defaultClient.execute(.GET, socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try self.defaultClient.execute(.POST, socketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + } + ) } func testConvenienceExecuteMethodsOverSecureSocket() throws { - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true, compress: false), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localSocketPathHTTPBin = HTTPBin( + .http1_1(ssl: true, compress: false), + bindTarget: .unixDomainSocket(path) + ) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(["GET"[...]], - try localClient.execute(secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["GET"[...]], - try localClient.execute(.GET, secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - XCTAssertEqual(["POST"[...]], - try localClient.execute(.POST, secureSocketPath: path, urlPath: "echo-method").wait().headers[canonicalForm: "X-Method-Used"]) - }) + XCTAssertEqual( + ["GET"[...]], + try localClient.execute(secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["GET"[...]], + try localClient.execute(.GET, secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + XCTAssertEqual( + ["POST"[...]], + try localClient.execute(.POST, secureSocketPath: path, urlPath: "echo-method").wait().headers[ + canonicalForm: "X-Method-Used" + ] + ) + } + ) } func testGet() throws { @@ -231,7 +304,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testPost() throws { - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .string("1234")).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .string("1234")) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -240,10 +314,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testPostWithGenericBody() throws { - let bodyData = Array("hello, world!").lazy.map { $0.uppercased().first!.asciiValue! } - let erasedData = AnyRandomAccessCollection(bodyData) + let bodyData = Array(Array("hello, world!").lazy.map { $0.uppercased().first!.asciiValue! }) - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(erasedData)).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .bytes(bodyData)) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -254,7 +328,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testPostWithFoundationDataBody() throws { let bodyData = Data("hello, world!".utf8) - let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .data(bodyData)).wait() + let response = try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + "post", body: .data(bodyData)) + .wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } let data = try JSONDecoder().decode(RequestInfo.self, from: bytes!) @@ -264,8 +339,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testGetHttps() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -277,8 +354,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testGetHttpsWithIP() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -296,8 +375,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try group.syncShutdownGracefully()) } let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -310,8 +391,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testGetHttpsWithIPv6() throws { try XCTSkipUnless(canBindIPv6Loopback, "Requires IPv6") let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .localhostIPv6RandomPort) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -330,8 +413,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try group.syncShutdownGracefully()) } let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .localhostIPv6RandomPort) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(group), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(group), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -343,14 +428,20 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testPostHttps() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - let request = try Request(url: "https://localhost:\(localHTTPBin.port)/post", method: .POST, body: .string("1234")) + let request = try Request( + url: "https://localhost:\(localHTTPBin.port)/post", + method: .POST, + body: .string("1234") + ) let response = try localClient.execute(request: request).wait() let bytes = response.body.flatMap { $0.getData(at: 0, length: $0.readableBytes) } @@ -362,8 +453,13 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testHttpRedirect() throws { let httpsBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -372,103 +468,246 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/302").wait() XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, self.defaultHTTPBinURLPrefix + "ok") + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [ + self.defaultHTTPBinURLPrefix + "redirect/302", + self.defaultHTTPBinURLPrefix + "ok", + ] + ) - response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/https?port=\(httpsBin.port)").wait() + response = try localClient.get(url: self.defaultHTTPBinURLPrefix + "redirect/https?port=\(httpsBin.port)") + .wait() XCTAssertEqual(response.status, .ok) - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpSocketPath in - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpsSocketPath in - let socketHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(httpSocketPath)) - let socketHTTPSBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(httpsSocketPath)) - defer { - XCTAssertNoThrow(try socketHTTPBin.shutdown()) - XCTAssertNoThrow(try socketHTTPSBin.shutdown()) - } - - // From HTTP or HTTPS to HTTP+UNIX should fail to redirect - var targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - var request = try Request(url: self.defaultHTTPBinURLPrefix + "redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - var response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - request = try Request(url: "https://localhost:\(httpsBin.port)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - // From HTTP or HTTPS to HTTPS+UNIX should also fail to redirect - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: self.defaultHTTPBinURLPrefix + "redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - request = try Request(url: "https://localhost:\(httpsBin.port)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .found) - XCTAssertEqual(response.headers.first(name: "Location"), targetURL) - - // ... while HTTP+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed - targetURL = self.defaultHTTPBinURLPrefix + "ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https://localhost:\(httpsBin.port)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - // ... and HTTPS+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed - targetURL = self.defaultHTTPBinURLPrefix + "ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https://localhost:\(httpsBin.port)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) - - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - - targetURL = "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" - request = try Request(url: "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", method: .GET, headers: ["X-Target-Redirect-URL": targetURL], body: nil) + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpSocketPath in + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { httpsSocketPath in + let socketHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(httpSocketPath)) + let socketHTTPSBin = HTTPBin( + .http1_1(ssl: true), + bindTarget: .unixDomainSocket(httpsSocketPath) + ) + defer { + XCTAssertNoThrow(try socketHTTPBin.shutdown()) + XCTAssertNoThrow(try socketHTTPSBin.shutdown()) + } - response = try localClient.execute(request: request).wait() - XCTAssertEqual(response.status, .ok) - }) - }) + // From HTTP or HTTPS to HTTP+UNIX should fail to redirect + var targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + var request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + var response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + request = try Request( + url: "https://localhost:\(httpsBin.port)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + // From HTTP or HTTPS to HTTPS+UNIX should also fail to redirect + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + request = try Request( + url: "https://localhost:\(httpsBin.port)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .found) + XCTAssertEqual(response.headers.first(name: "Location"), targetURL) + XCTAssertEqual(response.url, request.url) + XCTAssertEqual(response.history.map(\.request.url), [request.url]) + + // ... while HTTP+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed + targetURL = self.defaultHTTPBinURLPrefix + "ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = "https://localhost:\(httpsBin.port)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + // ... and HTTPS+UNIX to HTTP, HTTPS, or HTTP(S)+UNIX should succeed + targetURL = self.defaultHTTPBinURLPrefix + "ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = "https://localhost:\(httpsBin.port)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "http+unix://\(httpSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + + targetURL = + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/ok" + request = try Request( + url: + "https+unix://\(httpsSocketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed)!)/redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": targetURL], + body: nil + ) + + response = try localClient.execute(request: request).wait() + XCTAssertEqual(response.status, .ok) + XCTAssertEqual(response.url?.absoluteString, targetURL) + XCTAssertEqual( + response.history.map(\.request.url.absoluteString), + [request.url.absoluteString, targetURL] + ) + } + ) + } + ) } func testHttpHostRedirect() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -495,12 +734,37 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(.ok, response.status) } + func testLeadingSlashRelativeURL() throws { + let noLeadingSlashURL = URL( + string: "percent%2Fencoded/hello", + relativeTo: URL(string: self.defaultHTTPBinURLPrefix)! + )! + let withLeadingSlashURL = URL( + string: "/percent%2Fencoded/hello", + relativeTo: URL(string: self.defaultHTTPBinURLPrefix)! + )! + + let noLeadingSlashURLRequest = try HTTPClient.Request(url: noLeadingSlashURL, method: .GET) + let withLeadingSlashURLRequest = try HTTPClient.Request(url: withLeadingSlashURL, method: .GET) + + let noLeadingSlashURLResponse = try self.defaultClient.execute(request: noLeadingSlashURLRequest).wait() + let withLeadingSlashURLResponse = try self.defaultClient.execute(request: withLeadingSlashURLRequest).wait() + + XCTAssertEqual(noLeadingSlashURLResponse.status, .ok) + XCTAssertEqual(withLeadingSlashURLResponse.status, .ok) + } + func testMultipleContentLengthHeaders() throws { let body = ByteBuffer(string: "hello world!") var headers = HTTPHeaders() headers.add(name: "Content-Length", value: "12") - let request = try Request(url: self.defaultHTTPBinURLPrefix + "post", method: .POST, headers: headers, body: .byteBuffer(body)) + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "post", + method: .POST, + headers: headers, + body: .byteBuffer(body) + ) let response = try self.defaultClient.execute(request: request).wait() // if the library adds another content length header we'll get a bad request error. XCTAssertEqual(.ok, response.status) @@ -520,11 +784,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var request = try Request(url: self.defaultHTTPBinURLPrefix + "events/10/content-length") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in let delegate = try FileDownloadDelegate(path: path) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -532,24 +796,30 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) - return progress + return response } - XCTAssertEqual(50, progress.totalBytes) - XCTAssertEqual(50, progress.receivedBytes) + XCTAssertEqual(.ok, response.head.status) + XCTAssertEqual("50", response.head.headers.first(name: "content-length")) + + XCTAssertEqual(50, response.totalBytes) + XCTAssertEqual(50, response.receivedBytes) } func testFileDownloadError() throws { var request = try Request(url: self.defaultHTTPBinURLPrefix + "not-found") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in - let delegate = try FileDownloadDelegate(path: path, reportHead: { - XCTAssertEqual($0.status, .notFound) - }) + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in + let delegate = try FileDownloadDelegate( + path: path, + reportHead: { + XCTAssertEqual($0.status, .notFound) + } + ) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -557,11 +827,14 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertFalse(TemporaryFileHelpers.fileExists(path: path)) - return progress + return response } - XCTAssertEqual(nil, progress.totalBytes) - XCTAssertEqual(0, progress.receivedBytes) + XCTAssertEqual(.notFound, response.head.status) + XCTAssertFalse(response.head.headers.contains(name: "content-length")) + + XCTAssertEqual(nil, response.totalBytes) + XCTAssertEqual(0, response.receivedBytes) } func testFileDownloadCustomError() throws { @@ -569,12 +842,16 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { struct CustomError: Equatable, Error {} try TemporaryFileHelpers.withTemporaryFilePath { path in - let delegate = try FileDownloadDelegate(path: path, reportHead: { task, head in - XCTAssertEqual(head.status, .ok) - task.fail(reason: CustomError()) - }, reportProgress: { _, _ in - XCTFail("should never be called") - }) + let delegate = try FileDownloadDelegate( + path: path, + reportHead: { task, head in + XCTAssertEqual(head.status, .ok) + task.fail(reason: CustomError()) + }, + reportProgress: { _, _ in + XCTFail("should never be called") + } + ) XCTAssertThrowsError( try self.defaultClient.execute( request: request, @@ -596,8 +873,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testReadTimeout() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(read: .milliseconds(150)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(read: .milliseconds(150))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -609,8 +888,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testWriteTimeout() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(write: .nanoseconds(10)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(timeout: HTTPClient.Configuration.Timeout(write: .nanoseconds(10))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -618,19 +899,21 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // Create a request that writes a chunk, then waits longer than the configured write timeout, // and then writes again. This should trigger a write timeout error. - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "post", - method: .POST, - headers: ["transfer-encoding": "chunked"], - body: .stream { streamWriter in - _ = streamWriter.write(.byteBuffer(.init())) - - let promise = self.clientGroup.next().makePromise(of: Void.self) - self.clientGroup.next().scheduleTask(in: .milliseconds(3)) { - streamWriter.write(.byteBuffer(.init())).cascade(to: promise) - } + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "post", + method: .POST, + headers: ["transfer-encoding": "chunked"], + body: .stream { streamWriter in + _ = streamWriter.write(.byteBuffer(.init())) + + let promise = localClient.eventLoopGroup.next().makePromise(of: Void.self) + localClient.eventLoopGroup.next().scheduleTask(in: .milliseconds(3)) { + streamWriter.write(.byteBuffer(.init())).cascade(to: promise) + } - return promise.futureResult - }) + return promise.futureResult + } + ) XCTAssertThrowsError(try localClient.execute(request: request).wait()) { XCTAssertEqual($0 as? HTTPClientError, .writeTimeout) @@ -666,8 +949,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let url = "http://localhost:\(port)/get" #endif - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150)))) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100), read: .milliseconds(150))) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) @@ -679,7 +964,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testDeadline() { - XCTAssertThrowsError(try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "wait", deadline: .now() + .milliseconds(150)).wait()) { + XCTAssertThrowsError( + try self.defaultClient.get( + url: self.defaultHTTPBinURLPrefix + "wait", + deadline: .now() + .milliseconds(150) + ).wait() + ) { XCTAssertEqual($0 as? HTTPClientError, .deadlineExceeded) } } @@ -765,7 +1055,13 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let localHTTPBin = HTTPBin(proxy: .simulate(authorization: "Basic YWxhZGRpbjpvcGVuc2VzYW1l")) let localClient = HTTPClient( eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(proxy: .server(host: "localhost", port: localHTTPBin.port, authorization: .basic(username: "aladdin", password: "opensesame"))) + configuration: .init( + proxy: .server( + host: "localhost", + port: localHTTPBin.port, + authorization: .basic(username: "aladdin", password: "opensesame") + ) + ) ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -802,7 +1098,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testUploadStreaming() throws { - let body: HTTPClient.Body = .stream(length: 8) { writer in + let body: HTTPClient.Body = .stream(contentLength: 8) { writer in let buffer = ByteBuffer(string: "1234") return writer.write(.byteBuffer(buffer)).flatMap { let buffer = ByteBuffer(string: "4321") @@ -819,48 +1115,57 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testEventLoopArgument() throws { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(redirectConfiguration: .follow(max: 10, allowCycles: true)) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } - class EventLoopValidatingDelegate: HTTPClientResponseDelegate { + final class EventLoopValidatingDelegate: HTTPClientResponseDelegate { typealias Response = Bool let eventLoop: EventLoop - var result = false + let result = NIOLockedValueBox(false) init(eventLoop: EventLoop) { self.eventLoop = eventLoop } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.result = task.eventLoop === self.eventLoop + self.result.withLockedValue { $0 = task.eventLoop === self.eventLoop } return task.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws -> Bool { - return self.result + self.result.withLockedValue { $0 } } } let eventLoop = self.clientGroup.next() let delegate = EventLoopValidatingDelegate(eventLoop: eventLoop) var request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get") - var response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() + var response = try localClient.execute( + request: request, + delegate: delegate, + eventLoop: .delegate(on: eventLoop) + ).wait() XCTAssertEqual(true, response) // redirect request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "redirect/302") - response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)).wait() + response = try localClient.execute(request: request, delegate: delegate, eventLoop: .delegate(on: eventLoop)) + .wait() XCTAssertEqual(true, response) } func testDecompression() throws { let localHTTPBin = HTTPBin(.http1_1(compress: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(decompression: .enabled(limit: .none))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(decompression: .enabled(limit: .none)) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -869,7 +1174,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var body = "" for _ in 1...1000 { - body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." } for algorithm in [nil, "gzip", "deflate"] { @@ -911,7 +1217,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var body = "" for _ in 1...1000 { - body += "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." + body += + "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." } for algorithm: String? in [nil] { @@ -939,7 +1246,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testDecompressionLimit() throws { let localHTTPBin = HTTPBin(.http1_1(compress: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(decompression: .enabled(limit: .ratio(1)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(decompression: .enabled(limit: .ratio(1))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) @@ -957,30 +1267,47 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testLoopDetectionRedirectLimit() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 5, allowCycles: false))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 5, allowCycles: false) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").wait(), "Should fail with redirect limit") { error in + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").wait(), + "Should fail with redirect limit" + ) { error in XCTAssertEqual(error as? HTTPClientError, HTTPClientError.redirectCycleDetected) } } func testCountRedirectLimit() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none, redirectConfiguration: .follow(max: 10, allowCycles: true))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration( + certificateVerification: .none, + redirectConfiguration: .follow(max: 10, allowCycles: true) + ) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").timeout(after: .seconds(10)).wait()) { error in + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(localHTTPBin.port)/redirect/infinite1").timeout( + after: .seconds(10) + ).wait() + ) { error in XCTAssertEqual(error as? HTTPClientError, HTTPClientError.redirectLimitReached) } } @@ -998,13 +1325,15 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { defer { XCTAssertNoThrow(try localClient.syncShutdown()) } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://localhost:\(localHTTPBin.port)/redirect/target", - method: .GET, - headers: [ - "X-Target-Redirect-URL": "/redirect/target", - ] - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/redirect/target" + ] + ) + ) guard let request = maybeRequest else { return } XCTAssertThrowsError( @@ -1018,14 +1347,18 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let numberOfRequestsPerThread = 1000 let numberOfParallelWorkers = 5 - final class HTTPServer: ChannelInboundHandler { + final class HTTPServer: ChannelInboundHandler, Sendable { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart func channelRead(context: ChannelHandlerContext, data: NIOAny) { if case .end = self.unwrapInboundIn(data) { - let responseHead = HTTPServerResponsePart.head(.init(version: .init(major: 1, minor: 1), - status: .ok)) + let responseHead = HTTPServerResponsePart.head( + .init( + version: .init(major: 1, minor: 1), + status: .ok + ) + ) context.write(self.wrapOutboundOut(responseHead), promise: nil) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } @@ -1038,28 +1371,33 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) - .serverChannelOption(ChannelOptions.backlog, value: .init(numberOfParallelWorkers)) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline(withPipeliningAssistance: false, - withServerUpgrade: nil, - withErrorHandling: false).flatMap { - channel.pipeline.addHandler(HTTPServer()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) + .serverChannelOption(ChannelOptions.backlog, value: .init(numberOfParallelWorkers)) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline( + withPipeliningAssistance: false, + withServerUpgrade: nil, + withErrorHandling: false + ).flatMap { + channel.pipeline.addHandler(HTTPServer()) + } } - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } let url = "http://127.0.0.1:\(server?.localAddress?.port ?? -1)/hello" let g = DispatchGroup() + let defaultClient = self.defaultClient! for workerID in 0...whenAllComplete(tasks.map { $0.futureResult }, on: localClient.eventLoopGroup.next()).wait() + let results = try EventLoopFuture.whenAllComplete( + tasks.map { $0.futureResult }, + on: localClient.eventLoopGroup.next() + ).wait() for result in results { switch result { @@ -1241,9 +1633,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // We're speaking TLS to a plain text server. This will cause the handshake to fail but given // that the bytes "HTTP/1.1" aren't the start of a valid TLS packet, we can also get // errSSLPeerProtocolVersion because the first bytes contain the version. - XCTAssert(clientError.status == errSSLHandshakeFail || - clientError.status == errSSLPeerProtocolVersion, - "unexpected NWTLSError with status \(clientError.status)") + XCTAssert( + clientError.status == errSSLHandshakeFail || clientError.status == errSSLPeerProtocolVersion, + "unexpected NWTLSError with status \(clientError.status)" + ) #endif } else { guard let clientError = error as? NIOSSLError, case NIOSSLError.handshakeFailed = clientError else { @@ -1260,15 +1653,18 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) let configuration = try TLSConfiguration.makeServerConfiguration( certificateChain: NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, - privateKey: .file(keyPath) + privateKey: .privateKey(key) ) let sslContext = try NIOSSLContext(configuration: configuration) let server = ServerBootstrap(group: serverGroup) .childChannelInitializer { channel in - channel.pipeline.addHandler(NIOSSLServerHandler(context: sslContext)) + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } } let serverChannel = try server.bind(host: "localhost", port: 0).wait() defer { XCTAssertNoThrow(try serverChannel.close().wait()) } @@ -1287,7 +1683,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") #else guard let sslError = error as? NIOSSLError, - case .handshakeFailed(.sslError) = sslError else { + case .handshakeFailed(.sslError) = sslError + else { XCTFail("unexpected error \(error)") return } @@ -1300,15 +1697,18 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { /// openssl req -x509 -newkey rsa:4096 -keyout self_signed_key.pem -out self_signed_cert.pem -sha256 -days 99999 -nodes -subj '/CN=localhost' let certPath = Bundle.module.path(forResource: "self_signed_cert", ofType: "pem")! let keyPath = Bundle.module.path(forResource: "self_signed_key", ofType: "pem")! + let key = try NIOSSLPrivateKey(file: keyPath, format: .pem) let configuration = try TLSConfiguration.makeServerConfiguration( certificateChain: NIOSSLCertificate.fromPEMFile(certPath).map { .certificate($0) }, - privateKey: .file(keyPath) + privateKey: .privateKey(key) ) let sslContext = try NIOSSLContext(configuration: configuration) let server = ServerBootstrap(group: serverGroup) .childChannelInitializer { channel in - channel.pipeline.addHandler(NIOSSLServerHandler(context: sslContext)) + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(NIOSSLServerHandler(context: sslContext)) + } } let serverChannel = try server.bind(host: "localhost", port: 0).wait() defer { XCTAssertNoThrow(try serverChannel.close().wait()) } @@ -1319,7 +1719,9 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: config) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } - XCTAssertThrowsError(try localClient.get(url: "https://localhost:\(port)", deadline: .now() + .seconds(2)).wait()) { error in + XCTAssertThrowsError( + try localClient.get(url: "https://localhost:\(port)", deadline: .now() + .seconds(2)).wait() + ) { error in #if canImport(Network) guard let nwTLSError = error as? HTTPClient.NWTLSError else { XCTFail("could not cast \(error) of type \(type(of: error)) to \(HTTPClient.NWTLSError.self)") @@ -1328,7 +1730,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(nwTLSError.status, errSSLBadCert, "unexpected tls error: \(nwTLSError)") #else guard let sslError = error as? NIOSSLError, - case .handshakeFailed(.sslError) = sslError else { + case .handshakeFailed(.sslError) = sslError + else { XCTFail("unexpected error \(error)") return } @@ -1359,13 +1762,17 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let requestCount = 200 var futureResults = [EventLoopFuture]() for _ in 1...requestCount { - let req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", - method: .GET, - headers: ["X-internal-delay": "5", "Connection": "close"]) + let req = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "5", "Connection": "close"] + ) futureResults.append(self.defaultClient.execute(request: req)) } - XCTAssertNoThrow(try EventLoopFuture.andAllComplete(futureResults, on: eventLoop) - .timeout(after: .seconds(10)).wait()) + XCTAssertNoThrow( + try EventLoopFuture.andAllComplete(futureResults, on: eventLoop) + .timeout(after: .seconds(10)).wait() + ) } func testManyConcurrentRequestsWork() { @@ -1380,13 +1787,14 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { for w in 0..]() for i in 1...100 { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET, headers: ["X-internal-delay": "10"]) + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "10"] + ) let preference: HTTPClient.EventLoopPreference if i <= 50 { preference = .delegateAndChannel(on: first) @@ -1620,15 +2054,18 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let seenError = DispatchGroup() seenError.enter() var maybeSecondRequest: EventLoopFuture? - XCTAssertNoThrow(maybeSecondRequest = try el.submit { - let neverSucceedingRequest = localClient.get(url: url) - let secondRequest = neverSucceedingRequest.flatMapError { error in - XCTAssertEqual(.cancelled, error as? HTTPClientError) - seenError.leave() - return localClient.get(url: url) // <== this is the main part, during the error callout, we call back in - } - return secondRequest - }.wait()) + XCTAssertNoThrow( + maybeSecondRequest = try el.submit { + let neverSucceedingRequest = localClient.get(url: url) + let secondRequest = neverSucceedingRequest.flatMapError { error in + XCTAssertEqual(.cancelled, error as? HTTPClientError) + seenError.leave() + // v this is the main part, during the error callout, we call back in + return localClient.get(url: url) + } + return secondRequest + }.wait() + ) guard let secondRequest = maybeSecondRequest else { XCTFail("couldn't get request future") @@ -1654,13 +2091,15 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try localClient.syncShutdown()) } - XCTAssertEqual(.ok, - try el.flatSubmit { () -> EventLoopFuture in - localClient.get(url: url).flatMap { firstResponse in - XCTAssertEqual(.ok, firstResponse.status) - return localClient.get(url: url) // <== interesting bit here - } - }.wait().status) + XCTAssertEqual( + .ok, + try el.flatSubmit { () -> EventLoopFuture in + localClient.get(url: url).flatMap { firstResponse in + XCTAssertEqual(.ok, firstResponse.status) + return localClient.get(url: url) // <== interesting bit here + } + }.wait().status + ) } func testMakeSecondRequestWhilstFirstIsOngoing() { @@ -1677,11 +2116,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let url = "http://127.0.0.1:\(web.serverPort)" let firstRequest = client.get(url: url) - XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head // Now, the first request is ongoing but not complete, let's start a second one let secondRequest = client.get(url: url) - XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end + XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try web.writeOutbound(.end(nil))) @@ -1689,8 +2128,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(.ok, try firstRequest.wait().status) // Okay, first request done successfully, let's do the second one too. - XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head - XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end + XCTAssertNoThrow(XCTAssertNotNil(try web.readInbound())) // first request: .head + XCTAssertEqual(.end(nil), try web.readInbound()) // first request: .end XCTAssertNoThrow(try web.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .created)))) XCTAssertNoThrow(try web.writeOutbound(.end(nil))) @@ -1701,15 +2140,19 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // This tests just connecting to a URL where the whole URL is the UNIX domain socket path like // unix:///this/is/my/socket.sock // We don't really have a path component, so we'll have to use "/" - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + let target = "unix://\(path)" + XCTAssertEqual( + ["Yes"[...]], + try self.defaultClient.get(url: target).wait().headers[canonicalForm: "X-Is-This-Slash"] + ) } - let target = "unix://\(path)" - XCTAssertEqual(["Yes"[...]], - try self.defaultClient.get(url: target).wait().headers[canonicalForm: "X-Is-This-Slash"]) - }) + ) } func testUDSSocketAndPath() { @@ -1717,56 +2160,73 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // // 1. a "base path" which is the path to the UNIX domain socket // 2. an actual path which is the normal path in a regular URL like https://example.com/this/is/the/path - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(string: "/echo-uri", relativeTo: URL(string: "unix://\(path)")), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(string: "/echo-uri", relativeTo: URL(string: "unix://\(path)")), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testHTTPPlusUNIX() { // Here, we're testing a URL where the UNIX domain socket is encoded as the host name - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - defer { - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(httpURLWithSocketPath: path, uri: "/echo-uri"), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + defer { + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(httpURLWithSocketPath: path, uri: "/echo-uri"), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try self.defaultClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testHTTPSPlusUNIX() { // Here, we're testing a URL where the UNIX domain socket is encoded as the host name - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localHTTPBin.shutdown()) - } - guard let target = URL(httpsURLWithSocketPath: path, uri: "/echo-uri"), - let request = try? Request(url: target) else { - XCTFail("couldn't build URL for request") - return + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let localHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localHTTPBin.shutdown()) + } + guard let target = URL(httpsURLWithSocketPath: path, uri: "/echo-uri"), + let request = try? Request(url: target) + else { + XCTFail("couldn't build URL for request") + return + } + XCTAssertEqual( + ["/echo-uri"[...]], + try localClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"] + ) } - XCTAssertEqual(["/echo-uri"[...]], - try localClient.execute(request: request).wait().headers[canonicalForm: "X-Calling-URI"]) - }) + ) } func testUseExistingConnectionOnDifferentEL() throws { @@ -1780,20 +2240,27 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let eventLoops = (1...threadCount).map { _ in elg.next() } let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get") - let closingRequest = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", headers: ["Connection": "close"]) + let closingRequest = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + headers: ["Connection": "close"] + ) for (index, el) in eventLoops.enumerated() { if index.isMultiple(of: 2) { - XCTAssertNoThrow(try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + XCTAssertNoThrow( + try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait() + ) } else { - XCTAssertNoThrow(try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait()) + XCTAssertNoThrow( + try localClient.execute(request: request, eventLoop: .delegateAndChannel(on: el)).wait() + ) XCTAssertNoThrow(try localClient.execute(request: closingRequest, eventLoop: .indifferent).wait()) } } } func testWeRecoverFromServerThatClosesTheConnectionOnUs() { - final class ServerThatAcceptsThenRejects: ChannelInboundHandler { + final class ServerThatAcceptsThenRejects: ChannelInboundHandler, Sendable { typealias InboundIn = HTTPServerRequestPart typealias OutboundOut = HTTPServerResponsePart @@ -1819,8 +2286,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let last = self.requestNumber.loadThenWrappingIncrement(ordering: .relaxed) switch last { case 0, 2: - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), - promise: nil) + context.write( + self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) case 1: context.close(promise: nil) @@ -1833,20 +2302,24 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let requestNumber = ManagedAtomic(0) let connectionNumber = ManagedAtomic(0) - let sharedStateServerHandler = ServerThatAcceptsThenRejects(requestNumber: requestNumber, - connectionNumber: connectionNumber) + let sharedStateServerHandler = ServerThatAcceptsThenRejects( + requestNumber: requestNumber, + connectionNumber: connectionNumber + ) var maybeServer: Channel? - XCTAssertNoThrow(maybeServer = try ServerBootstrap(group: self.serverGroup) - .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) - .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - // We're deliberately adding a handler which is shared between multiple channels. This is normally - // very verboten but this handler is specially crafted to tolerate this. - channel.pipeline.addHandler(sharedStateServerHandler) + XCTAssertNoThrow( + maybeServer = try ServerBootstrap(group: self.serverGroup) + .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) + .childChannelInitializer { channel in + channel.pipeline.configureHTTPServerPipeline().flatMap { + // We're deliberately adding a handler which is shared between multiple channels. This is normally + // very verboten but this handler is specially crafted to tolerate this. + channel.pipeline.addHandler(sharedStateServerHandler) + } } - } - .bind(host: "127.0.0.1", port: 0) - .wait()) + .bind(host: "127.0.0.1", port: 0) + .wait() + ) guard let server = maybeServer else { XCTFail("couldn't create server") return @@ -1882,8 +2355,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { maximumAllowedIdleTimeInConnectionPool: .milliseconds(100) ) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: configuration) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: configuration + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } @@ -1906,7 +2381,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testAvoidLeakingTLSHandshakeCompletionPromise() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), configuration: .init(timeout: .init(connect: .milliseconds(100)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(100))) + ) let localHTTPBin = HTTPBin() let port = localHTTPBin.port XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -1953,9 +2431,14 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testValidationErrorsAreSurfaced() throws { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .TRACE, body: .stream { _ in - self.defaultClient.eventLoopGroup.next().makeSucceededFuture(()) - }) + let defaultClient = self.defaultClient! + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .TRACE, + body: .stream { _ in + defaultClient.eventLoopGroup.next().makeSucceededFuture(()) + } + ) let runningRequest = self.defaultClient.execute(request: request) XCTAssertThrowsError(try runningRequest.wait()) { error in XCTAssertEqual(HTTPClientError.traceRequestWithBody, error as? HTTPClientError) @@ -1973,9 +2456,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { private var bodyPartsSeenSoFar = 0 private var atEnd = false - init(headPromise: EventLoopPromise, - bodyPromises: [EventLoopPromise], - endPromise: EventLoopPromise) { + init( + headPromise: EventLoopPromise, + bodyPromises: [EventLoopPromise], + endPromise: EventLoopPromise + ) { self.headPromise = headPromise self.bodyPromises = bodyPromises self.endPromise = endPromise @@ -1991,8 +2476,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { self.bodyPartsSeenSoFar += 1 self.bodyPromises.dropFirst(myNumber).first?.succeed(bytes) ?? XCTFail("ouch, too many chunks") case .end: - context.write(self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), - promise: nil) + context.write( + self.wrapOutboundOut(.head(.init(version: .init(major: 1, minor: 1), status: .ok))), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: self.endPromise) self.atEnd = true } @@ -2005,8 +2492,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { struct NotFulfilledError: Error {} self.headPromise.fail(NotFulfilledError()) - self.bodyPromises.forEach { - $0.fail(NotFulfilledError()) + for promise in self.bodyPromises { + promise.fail(NotFulfilledError()) } self.endPromise.fail(NotFulfilledError()) } @@ -2027,12 +2514,16 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let streamWriterPromise = group.next().makePromise(of: HTTPClient.Body.StreamWriter.self) func makeServer() -> Channel? { - return try? ServerBootstrap(group: group) + try? ServerBootstrap(group: group) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(HTTPServer(headPromise: headPromise, - bodyPromises: bodyPromises, - endPromise: endPromise)) + channel.pipeline.configureHTTPServerPipeline().flatMapThrowing { + try channel.pipeline.syncOperations.addHandler( + HTTPServer( + headPromise: headPromise, + bodyPromises: bodyPromises, + endPromise: endPromise + ) + ) } } .serverChannelOption(ChannelOptions.socket(.init(SOL_SOCKET), .init(SO_REUSEADDR)), value: 1) @@ -2045,13 +2536,15 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { return nil } - return try? HTTPClient.Request(url: "http://\(localAddress.ipAddress!):\(localAddress.port!)", - method: .POST, - headers: ["transfer-encoding": "chunked"], - body: .stream { streamWriter in - streamWriterPromise.succeed(streamWriter) - return sentOffAllBodyPartsPromise.futureResult - }) + return try? HTTPClient.Request( + url: "http://\(localAddress.ipAddress!):\(localAddress.port!)", + method: .POST, + headers: ["transfer-encoding": "chunked"], + body: .stream { streamWriter in + streamWriterPromise.succeed(streamWriter) + return sentOffAllBodyPartsPromise.futureResult + } + ) } guard let server = makeServer(), let request = makeRequest(server: server) else { @@ -2083,35 +2576,46 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testUploadStreamingCallinToleratedFromOtsideEL() throws { - let request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .POST, body: .stream(length: 4) { writer in - let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) - // We have to toleare callins from any thread - DispatchQueue(label: "upload-streaming").async { - writer.write(.byteBuffer(ByteBuffer(string: "1234"))).whenComplete { _ in - promise.succeed(()) + let defaultClient = self.defaultClient! + let request = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .POST, + body: .stream(contentLength: 4) { writer in + let promise = defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + // We have to toleare callins from any thread + DispatchQueue(label: "upload-streaming").async { + writer.write(.byteBuffer(ByteBuffer(string: "1234"))).whenComplete { _ in + promise.succeed(()) + } } + return promise.futureResult } - return promise.futureResult - }) + ) XCTAssertNoThrow(try self.defaultClient.execute(request: request).wait()) } func testWeHandleUsSendingACloseHeaderCorrectly() { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["connection": "close"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["connection": "close"] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2127,21 +2631,27 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testWeHandleUsReceivingACloseHeaderCorrectly() { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-Connection": "close"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: ["X-Send-Back-Header-Connection": "close"] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2158,22 +2668,32 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testWeHandleUsSendingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-\(closeHeader.0)": - "foo,\(closeHeader.1),bar"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: [ + "X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar" + ] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } - guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + guard + let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } - guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + guard + let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2190,22 +2710,32 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testWeHandleUsReceivingACloseHeaderAmongstOtherConnectionHeadersCorrectly() { for closeHeader in [("connection", "close"), ("CoNneCTION", "ClOSe")] { - guard let req1 = try? Request(url: self.defaultHTTPBinURLPrefix + "stats", - method: .GET, - headers: ["X-Send-Back-Header-\(closeHeader.0)": - "foo,\(closeHeader.1),bar"]), + guard + let req1 = try? Request( + url: self.defaultHTTPBinURLPrefix + "stats", + method: .GET, + headers: [ + "X-Send-Back-Header-\(closeHeader.0)": + "foo,\(closeHeader.1),bar" + ] + ), let statsBytes1 = try? self.defaultClient.execute(request: req1).wait().body, - let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) else { + let stats1 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes1) + else { XCTFail("request 1 didn't work") return } - guard let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) else { + guard + let statsBytes2 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats2 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes2) + else { XCTFail("request 2 didn't work") return } - guard let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, - let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) else { + guard + let statsBytes3 = try? self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + "stats").wait().body, + let stats3 = try? JSONDecoder().decode(RequestInfo.self, from: statsBytes3) + else { XCTFail("request 3 didn't work") return } @@ -2223,28 +2753,35 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testLoggingCorrectlyAttachesRequestInformationEvenAfterDuringRedirect() { let logStore = CollectEverythingLogHandler.LogStore() - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var logger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) logger.logLevel = .trace logger[metadataKey: "custom-request-id"] = "abcd" var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "http://localhost:\(self.defaultHTTPBin.port)/redirect/target", - method: .GET, - headers: [ - "X-Target-Redirect-URL": "/get", - ] - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "http://localhost:\(self.defaultHTTPBin.port)/redirect/target", + method: .GET, + headers: [ + "X-Target-Redirect-URL": "/get" + ] + ) + ) guard let request = maybeRequest else { return } - XCTAssertNoThrow(try self.defaultClient.execute( - request: request, - eventLoop: .indifferent, - deadline: nil, - logger: logger - ).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) let logs = logStore.allEntries XCTAssertTrue(logs.allSatisfy { $0.metadata["custom-request-id"] == "abcd" }) @@ -2263,229 +2800,310 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertGreaterThan(secondRequestLogs.count, 0) XCTAssertTrue(secondRequestLogs.allSatisfy { $0.metadata["ahc-request-id"] == lastRequestID }) - logs.forEach { print($0) } + for log in logs { print(log) } } func testLoggingCorrectlyAttachesRequestInformation() { let logStore = CollectEverythingLogHandler.LogStore() - var loggerYolo001 = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var loggerYolo001 = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) loggerYolo001.logLevel = .trace loggerYolo001[metadataKey: "yolo-request-id"] = "yolo-001" - var loggerACME002 = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var loggerACME002 = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) loggerACME002.logLevel = .trace loggerACME002[metadataKey: "acme-request-id"] = "acme-002" guard let request1 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get"), - let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats"), - let request3 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "ok") else { + let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats"), + let request3 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "ok") + else { XCTFail("bad stuff, can't even make request structures") return } // === Request 1 (Yolo001) - XCTAssertNoThrow(try self.defaultClient.execute(request: request1, - eventLoop: .indifferent, - deadline: nil, - logger: loggerYolo001).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request1, + eventLoop: .indifferent, + deadline: nil, + logger: loggerYolo001 + ).wait() + ) let logsAfterReq1 = logStore.allEntries logStore.allEntries = [] // === Request 2 (Yolo001) - XCTAssertNoThrow(try self.defaultClient.execute(request: request2, - eventLoop: .indifferent, - deadline: nil, - logger: loggerYolo001).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request2, + eventLoop: .indifferent, + deadline: nil, + logger: loggerYolo001 + ).wait() + ) let logsAfterReq2 = logStore.allEntries logStore.allEntries = [] // === Request 3 (ACME002) - XCTAssertNoThrow(try self.defaultClient.execute(request: request3, - eventLoop: .indifferent, - deadline: nil, - logger: loggerACME002).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request3, + eventLoop: .indifferent, + deadline: nil, + logger: loggerACME002 + ).wait() + ) let logsAfterReq3 = logStore.allEntries logStore.allEntries = [] - // === Assertions - XCTAssertGreaterThan(logsAfterReq1.count, 0) - XCTAssertGreaterThan(logsAfterReq2.count, 0) - XCTAssertGreaterThan(logsAfterReq3.count, 0) + // === Assertions + XCTAssertGreaterThan(logsAfterReq1.count, 0) + XCTAssertGreaterThan(logsAfterReq2.count, 0) + XCTAssertGreaterThan(logsAfterReq3.count, 0) + + XCTAssert( + logsAfterReq1.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let yoloRequestID = entry.metadata["yolo-request-id"] + { + XCTAssertNil(entry.metadata["acme-request-id"]) + XCTAssertEqual("yolo-001", yoloRequestID) + XCTAssertNotNil(Int(httpRequestMetadata)) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssert( + logsAfterReq1.contains { entry in + // Since a new connection must be created first we expect that the request is queued + // and log message describing this is emitted. + entry.message == "Request was queued (waiting for a connection to become available)" + && entry.level == .debug + } + ) + XCTAssert( + logsAfterReq1.contains { entry in + // After the new connection was created we expect a log message that describes that the + // request was scheduled on a connection. The connection id must be set from here on. + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil + } + ) + + XCTAssert( + logsAfterReq2.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let yoloRequestID = entry.metadata["yolo-request-id"] + { + XCTAssertNil(entry.metadata["acme-request-id"]) + XCTAssertEqual("yolo-001", yoloRequestID) + XCTAssertNotNil(Int(httpRequestMetadata)) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssertFalse( + logsAfterReq2.contains { entry in + entry.message == "Request was queued (waiting for a connection to become available)" + } + ) + XCTAssert( + logsAfterReq2.contains { entry in + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil + } + ) - XCTAssert(logsAfterReq1.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let yoloRequestID = entry.metadata["yolo-request-id"] { - XCTAssertNil(entry.metadata["acme-request-id"]) - XCTAssertEqual("yolo-001", yoloRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false - } - }) - XCTAssert(logsAfterReq1.contains { entry in - // Since a new connection must be created first we expect that the request is queued - // and log message describing this is emitted. - entry.message == "Request was queued (waiting for a connection to become available)" - && entry.level == .debug - }) - XCTAssert(logsAfterReq1.contains { entry in - // After the new connection was created we expect a log message that describes that the - // request was scheduled on a connection. The connection id must be set from here on. - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) - - XCTAssert(logsAfterReq2.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let yoloRequestID = entry.metadata["yolo-request-id"] { - XCTAssertNil(entry.metadata["acme-request-id"]) - XCTAssertEqual("yolo-001", yoloRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false - } - }) - XCTAssertFalse(logsAfterReq2.contains { entry in - entry.message == "Request was queued (waiting for a connection to become available)" - }) - XCTAssert(logsAfterReq2.contains { entry in - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) - - XCTAssert(logsAfterReq3.allSatisfy { entry in - if let httpRequestMetadata = entry.metadata["ahc-request-id"], - let acmeRequestID = entry.metadata["acme-request-id"] { - XCTAssertNil(entry.metadata["yolo-request-id"]) - XCTAssertEqual("acme-002", acmeRequestID) - XCTAssertNotNil(Int(httpRequestMetadata)) - return true - } else { - XCTFail("log message doesn't contain the right IDs: \(entry)") - return false + XCTAssert( + logsAfterReq3.allSatisfy { entry in + if let httpRequestMetadata = entry.metadata["ahc-request-id"], + let acmeRequestID = entry.metadata["acme-request-id"] + { + XCTAssertNil(entry.metadata["yolo-request-id"]) + XCTAssertEqual("acme-002", acmeRequestID) + XCTAssertNotNil(Int(httpRequestMetadata)) + return true + } else { + XCTFail("log message doesn't contain the right IDs: \(entry)") + return false + } + } + ) + XCTAssertFalse( + logsAfterReq3.contains { entry in + entry.message == "Request was queued (waiting for a connection to become available)" + } + ) + XCTAssert( + logsAfterReq3.contains { entry in + entry.message == "Request was scheduled on connection" + && entry.level == .debug + && entry.metadata["ahc-connection-id"] != nil } - }) - XCTAssertFalse(logsAfterReq3.contains { entry in - entry.message == "Request was queued (waiting for a connection to become available)" - }) - XCTAssert(logsAfterReq3.contains { entry in - entry.message == "Request was scheduled on connection" - && entry.level == .debug - && entry.metadata["ahc-connection-id"] != nil - }) + ) } func testNothingIsLoggedAtInfoOrHigher() { let logStore = CollectEverythingLogHandler.LogStore() - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var logger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) logger.logLevel = .info guard let request1 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get"), - let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats") else { + let request2 = try? HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "stats") + else { XCTFail("bad stuff, can't even make request structures") return } // === Request 1 - XCTAssertNoThrow(try self.defaultClient.execute(request: request1, - eventLoop: .indifferent, - deadline: nil, - logger: logger).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request1, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) XCTAssertEqual(0, logStore.allEntries.count) // === Request 2 - XCTAssertNoThrow(try self.defaultClient.execute(request: request2, - eventLoop: .indifferent, - deadline: nil, - logger: logger).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + request: request2, + eventLoop: .indifferent, + deadline: nil, + logger: logger + ).wait() + ) XCTAssertEqual(0, logStore.allEntries.count) // === Synthesized Request - XCTAssertNoThrow(try self.defaultClient.execute(.GET, - url: self.defaultHTTPBinURLPrefix + "get", - body: nil, - deadline: nil, - logger: logger).wait()) + XCTAssertNoThrow( + try self.defaultClient.execute( + .GET, + url: self.defaultHTTPBinURLPrefix + "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) XCTAssertEqual(0, logStore.allEntries.count) XCTAssertEqual(0, self.backgroundLogStore.allEntries.filter { $0.level >= .info }.count) // === Synthesized Socket Path Request - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace - - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace + + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertNoThrow(try localClient.execute(.GET, - socketPath: path, - urlPath: "get", - body: nil, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + XCTAssertNoThrow( + try localClient.execute( + .GET, + socketPath: path, + urlPath: "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.allEntries.count) - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) - }) + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) + } + ) // === Synthesized Secure Socket Path Request - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace - - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace + + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertNoThrow(try localClient.execute(.GET, - secureSocketPath: path, - urlPath: "get", - body: nil, - deadline: nil, - logger: logger).wait()) - XCTAssertEqual(0, logStore.allEntries.count) + XCTAssertNoThrow( + try localClient.execute( + .GET, + secureSocketPath: path, + urlPath: "get", + body: nil, + deadline: nil, + logger: logger + ).wait() + ) + XCTAssertEqual(0, logStore.allEntries.count) - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) - }) + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .info }.count) + } + ) } func testAllMethodsLog() { func checkExpectationsWithLogger(type: String, _ body: (Logger, String) throws -> T) throws -> T { let logStore = CollectEverythingLogHandler.LogStore() - var logger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: logStore) - }) + var logger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: logStore) + } + ) logger.logLevel = .trace logger[metadataKey: "req"] = "yo-\(type)" @@ -2493,86 +3111,125 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let result = try body(logger, url) XCTAssertGreaterThan(logStore.allEntries.count, 0) - logStore.allEntries.forEach { entry in + for entry in logStore.allEntries { XCTAssertEqual("yo-\(type)", entry.metadata["req"] ?? "n/a") XCTAssertNotNil(Int(entry.metadata["ahc-request-id"] ?? "n/a")) } return result } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try self.defaultClient.get(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "PUT") { logger, url in - try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "PUT") { logger, url in + try self.defaultClient.put(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "POST") { logger, url in - try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "POST") { logger, url in + try self.defaultClient.post(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "DELETE") { logger, url in - try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "DELETE") { logger, url in + try self.defaultClient.delete(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "PATCH") { logger, url in - try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "PATCH") { logger, url in + try self.defaultClient.patch(url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() + }.status + ) - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "CHECKOUT") { logger, url in - try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "CHECKOUT") { logger, url in + try self.defaultClient.execute(.CHECKOUT, url: self.defaultHTTPBinURLPrefix + url, logger: logger) + .wait() + }.status + ) // No background activity expected here. XCTAssertEqual(0, self.backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace - let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + let localSocketPathHTTPBin = HTTPBin(bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try localClient.execute(socketPath: path, urlPath: url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try localClient.execute(socketPath: path, urlPath: url, logger: logger).wait() + }.status + ) - // No background activity expected here. - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - }) + // No background activity expected here. + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) + } + ) - XCTAssertNoThrow(try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in - let backgroundLogStore = CollectEverythingLogHandler.LogStore() - var backgroundLogger = Logger(label: "\(#function)", factory: { _ in - CollectEverythingLogHandler(logStore: backgroundLogStore) - }) - backgroundLogger.logLevel = .trace + XCTAssertNoThrow( + try TemporaryFileHelpers.withTemporaryUnixDomainSocketPathName { path in + let backgroundLogStore = CollectEverythingLogHandler.LogStore() + var backgroundLogger = Logger( + label: "\(#function)", + factory: { _ in + CollectEverythingLogHandler(logStore: backgroundLogStore) + } + ) + backgroundLogger.logLevel = .trace - let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none), - backgroundActivityLogger: backgroundLogger) - defer { - XCTAssertNoThrow(try localClient.syncShutdown()) - XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) - } + let localSocketPathHTTPBin = HTTPBin(.http1_1(ssl: true), bindTarget: .unixDomainSocket(path)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none), + backgroundActivityLogger: backgroundLogger + ) + defer { + XCTAssertNoThrow(try localClient.syncShutdown()) + XCTAssertNoThrow(try localSocketPathHTTPBin.shutdown()) + } - XCTAssertEqual(.notFound, try checkExpectationsWithLogger(type: "GET") { logger, url in - try localClient.execute(secureSocketPath: path, urlPath: url, logger: logger).wait() - }.status) + XCTAssertEqual( + .notFound, + try checkExpectationsWithLogger(type: "GET") { logger, url in + try localClient.execute(secureSocketPath: path, urlPath: url, logger: logger).wait() + }.status + ) - // No background activity expected here. - XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) - }) + // No background activity expected here. + XCTAssertEqual(0, backgroundLogStore.allEntries.filter { $0.level >= .debug }.count) + } + ) } func testClosingIdleConnectionsInPoolLogsInTheBackground() { @@ -2581,16 +3238,19 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try self.defaultClient.syncShutdown()) XCTAssertGreaterThanOrEqual(self.backgroundLogStore.allEntries.count, 0) - XCTAssert(self.backgroundLogStore.allEntries.contains { entry in - entry.message == "Shutting down connection pool" - }) - XCTAssert(self.backgroundLogStore.allEntries.allSatisfy { entry in - entry.metadata["ahc-request-id"] == nil && - entry.metadata["ahc-request"] == nil && - entry.metadata["ahc-pool-key"] != nil - }) + XCTAssert( + self.backgroundLogStore.allEntries.contains { entry in + entry.message == "Shutting down connection pool" + } + ) + XCTAssert( + self.backgroundLogStore.allEntries.allSatisfy { entry in + entry.metadata["ahc-request-id"] == nil && entry.metadata["ahc-request"] == nil + && entry.metadata["ahc-pool-key"] != nil + } + ) - self.defaultClient = nil // so it doesn't get shut down again. + self.defaultClient = nil // so it doesn't get shut down again. } func testUploadStreamingNoLength() throws { @@ -2615,8 +3275,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTFail("Unexpected part") } - XCTAssertNoThrow(try server.readInbound()) // .body - XCTAssertNoThrow(try server.readInbound()) // .end + XCTAssertNoThrow(try server.readInbound()) // .body + XCTAssertNoThrow(try server.readInbound()) // .end XCTAssertNoThrow(try server.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try server.writeOutbound(.end(nil))) @@ -2625,17 +3285,19 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testConnectErrorPropagatedToDelegate() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void - var error: Error? + let error = NIOLockedValueBox(nil) func didFinishRequest(task: HTTPClient.Task) throws {} func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.error = error + self.error.withLockedValue { $0 = error } } } - let httpClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(timeout: .init(connect: .milliseconds(10)))) + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(timeout: .init(connect: .milliseconds(10))) + ) defer { XCTAssertNoThrow(try httpClient.syncShutdown()) @@ -2647,12 +3309,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertThrowsError(try httpClient.execute(request: request, delegate: delegate).wait()) { XCTAssertEqualTypeAndValue($0, HTTPClientError.connectTimeout) - XCTAssertEqualTypeAndValue(delegate.error, HTTPClientError.connectTimeout) + XCTAssertEqualTypeAndValue(delegate.error.withLockedValue { $0 }, HTTPClientError.connectTimeout) } } func testDelegateCallinsTolerateRandomEL() throws { - class TestDelegate: HTTPClientResponseDelegate { + final class TestDelegate: HTTPClientResponseDelegate { typealias Response = Void let eventLoop: EventLoop @@ -2661,11 +3323,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func didReceiveHead(task: HTTPClient.Task, _: HTTPResponseHead) -> EventLoopFuture { - return self.eventLoop.makeSucceededFuture(()) + self.eventLoop.makeSucceededFuture(()) } func didReceiveBodyPart(task: HTTPClient.Task, _: ByteBuffer) -> EventLoopFuture { - return self.eventLoop.makeSucceededFuture(()) + self.eventLoop.makeSucceededFuture(()) } func didFinishRequest(task: HTTPClient.Task) throws {} @@ -2688,8 +3350,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let request = try HTTPClient.Request(url: "http://localhost:\(httpServer.serverPort)/") let future = httpClient.execute(request: request, delegate: delegate) - XCTAssertNoThrow(try httpServer.readInbound()) // .head - XCTAssertNoThrow(try httpServer.readInbound()) // .end + XCTAssertNoThrow(try httpServer.readInbound()) // .head + XCTAssertNoThrow(try httpServer.readInbound()) // .end XCTAssertNoThrow(try httpServer.writeOutbound(.head(.init(version: .init(major: 1, minor: 1), status: .ok)))) XCTAssertNoThrow(try httpServer.writeOutbound(.body(.byteBuffer(ByteBuffer(string: "1234"))))) @@ -2698,18 +3360,58 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try future.wait()) } + func testDelegateGetsErrorsFromCreatingRequestBag() throws { + // We want to test that we propagate errors to the delegate from failures to construct the + // request bag. Those errors only come from invalid headers. + final class TestDelegate: HTTPClientResponseDelegate, Sendable { + typealias Response = Void + let error: NIOLockedValueBox = .init(nil) + func didFinishRequest(task: HTTPClient.Task) throws {} + func didReceiveError(task: HTTPClient.Task, _ error: Error) { + self.error.withLockedValue { $0 = error } + } + } + + let httpClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup) + ) + + defer { + XCTAssertNoThrow(try httpClient.syncShutdown()) + } + + // 198.51.100.254 is reserved for documentation only + var request = try HTTPClient.Request(url: "http://198.51.100.254:65535/get") + request.headers.replaceOrAdd(name: "Not-ASCII", value: "not-fine\n") + let delegate = TestDelegate() + + XCTAssertThrowsError(try httpClient.execute(request: request, delegate: delegate).wait()) { + XCTAssertEqualTypeAndValue($0, HTTPClientError.invalidHeaderFieldValues(["not-fine\n"])) + XCTAssertEqualTypeAndValue( + delegate.error.withLockedValue { $0 }, + HTTPClientError.invalidHeaderFieldValues(["not-fine\n"]) + ) + } + } + func testContentLengthTooLongFails() throws { let url = self.defaultHTTPBinURLPrefix + "post" + let defaultClient = self.defaultClient! XCTAssertThrowsError( - try self.defaultClient.execute(request: - Request(url: url, - body: .stream(length: 10) { streamWriter in - let promise = self.defaultClient.eventLoopGroup.next().makePromise(of: Void.self) + try self.defaultClient.execute( + request: + Request( + url: url, + body: .stream(contentLength: 10) { streamWriter in + let promise = defaultClient.eventLoopGroup.next().makePromise(of: Void.self) DispatchQueue(label: "content-length-test").async { streamWriter.write(.byteBuffer(ByteBuffer(string: "1"))).cascade(to: promise) } return promise.futureResult - })).wait()) { error in + } + ) + ).wait() + ) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. @@ -2731,11 +3433,16 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let url = self.defaultHTTPBinURLPrefix + "post" let tooLong = "XBAD BAD BAD NOT HTTP/1.1\r\n\r\n" XCTAssertThrowsError( - try self.defaultClient.execute(request: - Request(url: url, - body: .stream(length: 1) { streamWriter in + try self.defaultClient.execute( + request: + Request( + url: url, + body: .stream(contentLength: 1) { streamWriter in streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) - })).wait()) { error in + } + ) + ).wait() + ) { error in XCTAssertEqual(error as! HTTPClientError, HTTPClientError.bodyLengthMismatch) } // Quickly try another request and check that it works. If we by accident wrote some extra bytes into the @@ -2756,7 +3463,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testBodyUploadAfterEndFails() { let url = self.defaultHTTPBinURLPrefix + "post" - func uploader(_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture { + let uploader = { @Sendable (_ streamWriter: HTTPClient.Body.StreamWriter) -> EventLoopFuture in let done = streamWriter.write(.byteBuffer(ByteBuffer(string: "X"))) done.recover { error in XCTFail("unexpected error \(error)") @@ -2777,7 +3484,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } var request: HTTPClient.Request? - XCTAssertNoThrow(request = try Request(url: url, body: .stream(length: 1, uploader))) + XCTAssertNoThrow(request = try Request(url: url, body: .stream(contentLength: 1, uploader))) XCTAssertThrowsError(try self.defaultClient.execute(request: XCTUnwrap(request)).wait()) { XCTAssertEqual($0 as? HTTPClientError, .writeAfterRequestSent) } @@ -2792,6 +3499,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // second connection. _ = self.defaultClient.get(url: "http://localhost:\(self.defaultHTTPBin.port)/events/10/1") + let clientGroup = self.clientGroup! var request = try HTTPClient.Request(url: "http://localhost:\(self.defaultHTTPBin.port)/wait", method: .POST) request.body = .stream { writer in // Start writing chunks so tha we will try to write after read timeout is thrown @@ -2799,8 +3507,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { _ = writer.write(.byteBuffer(ByteBuffer(string: "1234"))) } - let promise = self.clientGroup.next().makePromise(of: Void.self) - self.clientGroup.next().scheduleTask(in: .milliseconds(3)) { + let promise = clientGroup.next().makePromise(of: Void.self) + clientGroup.next().scheduleTask(in: .milliseconds(3)) { writer.write(.byteBuffer(ByteBuffer(string: "1234"))).cascade(to: promise) } @@ -2809,11 +3517,13 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // We specify a deadline of 2 ms co that request will be timed out before all chunks are writtent, // we need to verify that second error on write after timeout does not lead to double-release. - XCTAssertThrowsError(try self.defaultClient.execute(request: request, deadline: .now() + .milliseconds(2)).wait()) + XCTAssertThrowsError( + try self.defaultClient.execute(request: request, deadline: .now() + .milliseconds(2)).wait() + ) } func testSSLHandshakeErrorPropagation() throws { - class CloseHandler: ChannelInboundHandler { + final class CloseHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Any func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -2870,11 +3580,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testSSLHandshakeErrorPropagationDelayedClose() throws { // This is as the test above, but the close handler delays its close action by a few hundred ms. // This will tend to catch the pipeline at different weird stages, and flush out different bugs. - class CloseHandler: ChannelInboundHandler { + final class CloseHandler: ChannelInboundHandler, Sendable { typealias InboundIn = Any func channelRead(context: ChannelHandlerContext, data: NIOAny) { - context.eventLoop.scheduleTask(in: .milliseconds(100)) { + context.eventLoop.assumeIsolated().scheduleTask(in: .milliseconds(100)) { context.close(promise: nil) } } @@ -2931,8 +3641,8 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let server = try ServerBootstrap(group: self.serverGroup) .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) .childChannelInitializer { channel in - channel.pipeline.configureHTTPServerPipeline().flatMap { - channel.pipeline.addHandler(CloseWithoutClosingServerHandler(group.leave)) + channel.pipeline.configureHTTPServerPipeline().flatMapThrowing { + try channel.pipeline.syncOperations.addHandler(CloseWithoutClosingServerHandler(group.leave)) } } .bind(host: "localhost", port: 0) @@ -2971,7 +3681,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let body: HTTPClient.Body = .stream { writer in let finalPromise = writeEL.makePromise(of: Void.self) - func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { // always invoke from the wrong el to test thread safety writeEL.preconditionInEventLoop() @@ -3021,10 +3731,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var request = try Request(url: httpBin.baseURL) request.body = .byteBuffer(body) - XCTAssertThrowsError(try self.defaultClient.execute( - request: request, - delegate: ResponseAccumulator(request: request, maxBodySize: 10) - ).wait()) { error in + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + ) { error in XCTAssertTrue(error is ResponseAccumulator.ResponseTooBigError, "unexpected error \(error)") } } @@ -3071,10 +3783,12 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { request.body = .stream { writer in writer.write(.byteBuffer(body)) } - XCTAssertThrowsError(try self.defaultClient.execute( - request: request, - delegate: ResponseAccumulator(request: request, maxBodySize: 10) - ).wait()) { error in + XCTAssertThrowsError( + try self.defaultClient.execute( + request: request, + delegate: ResponseAccumulator(request: request, maxBodySize: 10) + ).wait() + ) { error in XCTAssertTrue(error is ResponseAccumulator.ResponseTooBigError, "unexpected error \(error)") } } @@ -3100,7 +3814,9 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // In this test, we test that a request can continue to stream its body after the response head and end // was received where the end is a 200. func testBiDirectionalStreamingEarly200() { - let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTP200DelayedHandler(bodyPartsBeforeResponse: 1) } + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + HTTP200DelayedHandler(bodyPartsBeforeResponse: 1) + } defer { XCTAssertNoThrow(try httpBin.shutdown()) } let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) @@ -3116,7 +3832,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let body: HTTPClient.Body = .stream { writer in let finalPromise = writeEL.makePromise(of: Void.self) - func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { // always invoke from the wrong el to test thread safety writeEL.preconditionInEventLoop() @@ -3154,7 +3870,9 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // This test is identical to the one above, except that we send another request immediately after. This is a regression // test for https://github.com/swift-server/async-http-client/issues/595. func testBiDirectionalStreamingEarly200DoesntPreventUsFromSendingMoreRequests() { - let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in HTTP200DelayedHandler(bodyPartsBeforeResponse: 1) } + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + HTTP200DelayedHandler(bodyPartsBeforeResponse: 1) + } defer { XCTAssertNoThrow(try httpBin.shutdown()) } let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) @@ -3167,7 +3885,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let body: HTTPClient.Body = .stream { writer in let finalPromise = writeEL.makePromise(of: Void.self) - func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { // always invoke from the wrong el to test thread safety writeEL.preconditionInEventLoop() @@ -3212,7 +3930,9 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } let onClosePromise = eventLoopGroup.next().makePromise(of: Void.self) - let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in ExpectClosureServerHandler(onClosePromise: onClosePromise) } + let httpBin = HTTPBin(.http1_1(ssl: false, compress: false)) { _ in + ExpectClosureServerHandler(onClosePromise: onClosePromise) + } defer { XCTAssertNoThrow(try httpBin.shutdown()) } let writeEL = eventLoopGroup.next() @@ -3223,7 +3943,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { let body: HTTPClient.Body = .stream { writer in let finalPromise = writeEL.makePromise(of: Void.self) - func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { + @Sendable func writeLoop(_ writer: HTTPClient.Body.StreamWriter, index: Int) { // always invoke from the wrong el to test thread safety writeEL.preconditionInEventLoop() @@ -3270,8 +3990,10 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { tlsConfig.maximumTLSVersion = .tlsv12 tlsConfig.certificateVerification = .none let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(tlsConfiguration: tlsConfig) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -3289,11 +4011,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var request = try Request(url: self.defaultHTTPBinURLPrefix + "chunked") request.headers.add(name: "Accept", value: "text/event-stream") - let progress = - try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Progress in + let response = + try TemporaryFileHelpers.withTemporaryFilePath { path -> FileDownloadDelegate.Response in let delegate = try FileDownloadDelegate(path: path) - let progress = try self.defaultClient.execute( + let response = try self.defaultClient.execute( request: request, delegate: delegate ) @@ -3301,11 +4023,15 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { try XCTAssertEqual(50, TemporaryFileHelpers.fileSize(path: path)) - return progress + return response } - XCTAssertEqual(nil, progress.totalBytes) - XCTAssertEqual(50, progress.receivedBytes) + XCTAssertEqual(.ok, response.head.status) + XCTAssertEqual("chunked", response.head.headers.first(name: "transfer-encoding")) + XCTAssertFalse(response.head.headers.contains(name: "content-length")) + + XCTAssertEqual(nil, response.totalBytes) + XCTAssertEqual(50, response.receivedBytes) } func testCloseWhileBackpressureIsExertedIsFine() throws { @@ -3346,12 +4072,16 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testRequestSpecificTLS() throws { - let configuration = HTTPClient.Configuration(tlsConfiguration: nil, - timeout: .init(), - decompression: .disabled) + let configuration = HTTPClient.Configuration( + tlsConfiguration: nil, + timeout: .init(), + decompression: .disabled + ) let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: configuration) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: configuration + ) let decoder = JSONDecoder() defer { @@ -3362,7 +4092,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // First two requests use identical TLS configurations. var tlsConfig = TLSConfiguration.makeClientConfiguration() tlsConfig.certificateVerification = .none - let firstRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig) + let firstRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig + ) let firstResponse = try localClient.execute(request: firstRequest).wait() guard let firstBody = firstResponse.body else { XCTFail("No request body found") @@ -3370,7 +4104,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } let firstConnectionNumber = try decoder.decode(RequestInfo.self, from: firstBody).connectionNumber - let secondRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig) + let secondRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig + ) let secondResponse = try localClient.execute(request: secondRequest).wait() guard let secondBody = secondResponse.body else { XCTFail("No request body found") @@ -3382,7 +4120,11 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { var tlsConfig2 = TLSConfiguration.makeClientConfiguration() tlsConfig2.certificateVerification = .none tlsConfig2.maximumTLSVersion = .tlsv1 - let thirdRequest = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, tlsConfiguration: tlsConfig2) + let thirdRequest = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + tlsConfiguration: tlsConfig2 + ) let thirdResponse = try localClient.execute(request: thirdRequest).wait() guard let thirdBody = thirdResponse.body else { XCTFail("No request body found") @@ -3393,8 +4135,16 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(firstResponse.status, .ok) XCTAssertEqual(secondResponse.status, .ok) XCTAssertEqual(thirdResponse.status, .ok) - XCTAssertEqual(firstConnectionNumber, secondConnectionNumber, "Identical TLS configurations did not use the same connection") - XCTAssertNotEqual(thirdConnectionNumber, firstConnectionNumber, "Different TLS configurations did not use different connections.") + XCTAssertEqual( + firstConnectionNumber, + secondConnectionNumber, + "Identical TLS configurations did not use the same connection" + ) + XCTAssertNotEqual( + thirdConnectionNumber, + firstConnectionNumber, + "Different TLS configurations did not use different connections." + ) } func testRequestWithHeaderTransferEncodingIdentityDoesNotFail() { @@ -3419,7 +4169,9 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { func testMassiveDownload() { var response: HTTPClient.Response? - XCTAssertNoThrow(response = try self.defaultClient.get(url: "\(self.defaultHTTPBinURLPrefix)mega-chunked").wait()) + XCTAssertNoThrow( + response = try self.defaultClient.get(url: "\(self.defaultHTTPBinURLPrefix)mega-chunked").wait() + ) XCTAssertEqual(.ok, response?.status) XCTAssertEqual(response?.version, .http1_1) @@ -3446,11 +4198,13 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { } func testMassiveHeaderHTTP2() throws { - let bin = HTTPBin(.http2(settings: [ - .init(parameter: .maxConcurrentStreams, value: 100), - .init(parameter: .maxHeaderListSize, value: 1024 * 256), - .init(parameter: .maxFrameSize, value: 1024 * 256), - ])) + let bin = HTTPBin( + .http2(settings: [ + .init(parameter: .maxConcurrentStreams, value: 100), + .init(parameter: .maxHeaderListSize, value: 1024 * 256), + .init(parameter: .maxFrameSize, value: 1024 * 256), + ]) + ) defer { XCTAssertNoThrow(try bin.shutdown()) } let client = HTTPClient( @@ -3473,12 +4227,76 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertNoThrow(try client.execute(request: request).wait()) } + func testCancelingRequestAfterRedirect() throws { + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": self.defaultHTTPBinURLPrefix + "wait"], + body: nil + ) + + final class CancelAfterRedirect: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + } + + let task = defaultClient.execute( + request: request, + delegate: CancelAfterRedirect(), + deadline: .now() + .seconds(1) + ) + + // there is currently no HTTPClientResponseDelegate method to ensure the redirect occurs before we cancel, so we just sleep for 500ms + Thread.sleep(forTimeInterval: 0.5) + + task.cancel() + + XCTAssertThrowsError(try task.wait()) { error in + guard case let error = error as? HTTPClientError, error == .cancelled else { + return XCTFail("Should fail with cancelled") + } + } + } + + func testFailingRequestAfterRedirect() throws { + let request = try Request( + url: self.defaultHTTPBinURLPrefix + "redirect/target", + method: .GET, + headers: ["X-Target-Redirect-URL": self.defaultHTTPBinURLPrefix + "wait"], + body: nil + ) + + final class FailAfterRedirect: HTTPClientResponseDelegate, Sendable { + init() {} + func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} + } + + let task = defaultClient.execute( + request: request, + delegate: FailAfterRedirect(), + deadline: .now() + .seconds(1) + ) + + // there is currently no HTTPClientResponseDelegate method to ensure the redirect occurs before we fail, so we just sleep for 500ms + Thread.sleep(forTimeInterval: 0.5) + + struct TestError: Error {} + + task.fail(reason: TestError()) + + XCTAssertThrowsError(try task.wait()) { error in + guard error is TestError else { + return XCTFail("Should fail with TestError") + } + } + } + func testCancelingHTTP1RequestAfterHeaderSend() throws { var request = try HTTPClient.Request(url: self.defaultHTTPBin.baseURL + "/wait", method: .POST) // non-empty body is important request.body = .byteBuffer(ByteBuffer([1])) - class CancelAfterHeadSend: HTTPClientResponseDelegate { + final class CancelAfterHeadSend: HTTPClientResponseDelegate, Sendable { init() {} func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { @@ -3495,7 +4313,7 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // non-empty body is important request.body = .byteBuffer(ByteBuffer([1])) - class CancelAfterHeadSend: HTTPClientResponseDelegate { + final class CancelAfterHeadSend: HTTPClientResponseDelegate, Sendable { init() {} func didFinishRequest(task: AsyncHTTPClient.HTTPClient.Task) throws {} func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { @@ -3574,6 +4392,26 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { XCTAssertEqual(.ok, response.status) } + func testClientWithMultipath() throws { + do { + var conf = HTTPClient.Configuration() + conf.enableMultipath = true + let client = HTTPClient(configuration: conf) + defer { + XCTAssertNoThrow(try client.shutdown().wait()) + } + let response = try client.get(url: self.defaultHTTPBinURLPrefix + "get").wait() + XCTAssertEqual(.ok, response.status) + } catch let error as IOError + where error.errnoCode == EINVAL || error.errnoCode == EPROTONOSUPPORT || error.errnoCode == ENOPROTOOPT + { + // some old Linux kernels don't support MPTCP, skip this test in this case + // see https://www.mptcp.dev/implementation.html for details about each type + // of error + throw XCTSkip() + } + } + func testSingletonClientWorks() throws { let response = try HTTPClient.shared.get(url: self.defaultHTTPBinURLPrefix + "get").wait() XCTAssertEqual(.ok, response.status) @@ -3605,18 +4443,18 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // ! is safe, assigned above request.tlsConfiguration!.certificateVerification = .none - let response1 = try await client.execute(request, timeout: /* infinity */ .hours(99)) + let response1 = try await client.execute(request, timeout: .hours(99)) // 99h ~= infinity XCTAssertEqual(.ok, response1.status) // For the second request, we reset the TLS config request.tlsConfiguration = nil do { - let response2 = try await client.execute(request, timeout: /* infinity */ .hours(99)) + let response2 = try await client.execute(request, timeout: .hours(99)) // 99h ~= infinity XCTFail("shouldn't succeed, self-signed cert: \(response2)") } catch { switch error as? NIOSSLError { case .some(.handshakeFailed(_)): - () // ok + () // ok default: XCTFail("unexpected error: \(error)") } @@ -3627,7 +4465,184 @@ final class HTTPClientTests: XCTestCaseHTTPClientTestsBaseClass { // ! is safe, assigned above request.tlsConfiguration!.certificateVerification = .none - let response3 = try await client.execute(request, timeout: /* infinity */ .hours(99)) + let response3 = try await client.execute(request, timeout: .hours(99)) // 99h ~= infinity XCTAssertEqual(.ok, response3.status) } + + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) + func testRequestBasicAuth() async throws { + var request = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix) + request.setBasicAuth(username: "foo", password: "bar") + XCTAssertEqual(request.headers.first(name: "Authorization"), "Basic Zm9vOmJhcg==") + } + + func runBaseTestForHTTP1ConnectionDebugInitializer(ssl: Bool) { + let connectionDebugInitializerUtil = CountingDebugInitializerUtil() + + // Initializing even with just `http1_1ConnectionDebugInitializer` (rather than manually + // modifying `config`) to ensure that the matching `init` actually wires up this argument + // with the respective property. This is necessary as these parameters are defaulted and can + // be easy to miss. + var config = HTTPClient.Configuration( + http1_1ConnectionDebugInitializer: { channel in + connectionDebugInitializerUtil.initialize(channel: channel) + } + ) + config.httpVersion = .http1Only + + if ssl { + config.tlsConfiguration = .clientDefault + config.tlsConfiguration?.certificateVerification = .none + } + + let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100) + var configWithHigherTimeout = config + configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout) + + let clientWithHigherTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithHigherTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) } + + let bin = HTTPBin(.http1_1(ssl: ssl, compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let scheme = ssl ? "https" : "http" + + for _ in 0..<3 { + XCTAssertNoThrow( + try clientWithHigherTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait() + ) + } + + // Even though multiple requests were made, the connection debug initializer must be called + // only once. + XCTAssertEqual(connectionDebugInitializerUtil.executionCount, 1) + + let lowerConnectTimeout = CountingDebugInitializerUtil.duration - .milliseconds(100) + var configWithLowerTimeout = config + configWithLowerTimeout.timeout = .init(connect: lowerConnectTimeout) + + let clientWithLowerTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithLowerTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithLowerTimeout.syncShutdown()) } + + XCTAssertThrowsError( + try clientWithLowerTimeout.get(url: "\(scheme)://localhost:\(bin.port)/get").wait() + ) { + XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) + } + } + + func testHTTP1PlainTextConnectionDebugInitializer() { + runBaseTestForHTTP1ConnectionDebugInitializer(ssl: false) + } + + func testHTTP1EncryptedConnectionDebugInitializer() { + runBaseTestForHTTP1ConnectionDebugInitializer(ssl: true) + } + + func testHTTP2ConnectionAndStreamChannelDebugInitializers() { + let connectionDebugInitializerUtil = CountingDebugInitializerUtil() + let streamChannelDebugInitializerUtil = CountingDebugInitializerUtil() + + // Initializing even with just `http2ConnectionDebugInitializer` and + // `http2StreamChannelDebugInitializer` (rather than manually modifying `config`) to ensure + // that the matching `init` actually wires up these arguments with the respective + // properties. This is necessary as these parameters are defaulted and can be easy to miss. + var config = HTTPClient.Configuration( + http2ConnectionDebugInitializer: { channel in + connectionDebugInitializerUtil.initialize(channel: channel) + }, + http2StreamChannelDebugInitializer: { channel in + streamChannelDebugInitializerUtil.initialize(channel: channel) + } + ) + config.tlsConfiguration = .clientDefault + config.tlsConfiguration?.certificateVerification = .none + config.httpVersion = .automatic + + let higherConnectTimeout = CountingDebugInitializerUtil.duration + .milliseconds(100) + var configWithHigherTimeout = config + configWithHigherTimeout.timeout = .init(connect: higherConnectTimeout) + + let clientWithHigherTimeout = HTTPClient( + eventLoopGroupProvider: .singleton, + configuration: configWithHigherTimeout, + backgroundActivityLogger: Logger( + label: "HTTPClient", + factory: StreamLogHandler.standardOutput(label:) + ) + ) + defer { XCTAssertNoThrow(try clientWithHigherTimeout.syncShutdown()) } + + let bin = HTTPBin(.http2(compress: false)) + defer { XCTAssertNoThrow(try bin.shutdown()) } + + let numberOfRequests = 3 + + for _ in 0..(0) + var executionCount: Int { self._executionCount.withLockedValue { $0 } } + + /// The minimum time to spend running the debug initializer. + static let duration: TimeAmount = .milliseconds(300) + + /// The actual debug initializer. + func initialize(channel: Channel) -> EventLoopFuture { + self._executionCount.withLockedValue { $0 += 1 } + + let someScheduledTask = channel.eventLoop.scheduleTask(in: Self.duration) { + channel.eventLoop.makeSucceededVoidFuture() + } + + return someScheduledTask.futureResult.flatMap { $0 } + } } diff --git a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift b/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift index 854d9092c..b63eb7cba 100644 --- a/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPClientUncleanSSLConnectionShutdownTests.swift @@ -155,7 +155,8 @@ final class HTTPClientUncleanSSLConnectionShutdownTests: XCTestCase { ) defer { XCTAssertNoThrow(try client.syncShutdown()) } - XCTAssertThrowsError(try client.get(url: "https://localhost:\(httpBin.port)/transferencodingtruncated").wait()) { + XCTAssertThrowsError(try client.get(url: "https://localhost:\(httpBin.port)/transferencodingtruncated").wait()) + { XCTAssertEqual($0 as? HTTPParserError, .invalidEOFState) } } @@ -184,7 +185,7 @@ final class HTTPBinForSSLUncleanShutdown { let serverChannel: Channel var port: Int { - return Int(self.serverChannel.localAddress!.port!) + Int(self.serverChannel.localAddress!.port!) } init() { @@ -231,61 +232,61 @@ private final class HTTPBinForSSLUncleanShutdownHandler: ChannelInboundHandler { switch req.uri { case "/nocontentlength": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - \r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + \r\n\ + foo + """ case "/nocontent": response = """ - HTTP/1.1 204 OK\r\n\ - Connection: close\r\n\ - \r\n - """ + HTTP/1.1 204 OK\r\n\ + Connection: close\r\n\ + \r\n + """ case "/noresponse": response = nil case "/wrongcontentlength": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Content-Length: 6\r\n\ - \r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Content-Length: 6\r\n\ + \r\n\ + foo + """ case "/transferencoding": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 3\r\n\ - foo\r\n\ - 0\r\n\ - \r\n - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 3\r\n\ + foo\r\n\ + 0\r\n\ + \r\n + """ case "/transferencodingtruncated": response = """ - HTTP/1.1 200 OK\r\n\ - Connection: close\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n\ - 12\r\n\ - foo - """ + HTTP/1.1 200 OK\r\n\ + Connection: close\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n\ + 12\r\n\ + foo + """ default: response = """ - HTTP/1.1 404 OK\r\n\ - Connection: close\r\n\ - Content-Length: 9\r\n\ - \r\n\ - Not Found - """ + HTTP/1.1 404 OK\r\n\ + Connection: close\r\n\ + Content-Length: 9\r\n\ + \r\n\ + Not Found + """ } if let response = response { diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift index 476584972..15cc9e7e9 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+FactoryTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOPosix @@ -20,18 +19,22 @@ import NIOSOCKS import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_FactoryTests: XCTestCase { func testConnectionCreationTimesoutIfDeadlineIsInThePast() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -45,13 +48,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - requester: ExplodingRequester(), - connectionID: 1, - deadline: .now() - .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() - .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .connectTimeout) } @@ -62,12 +66,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -82,13 +88,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - requester: ExplodingRequester(), - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .socksHandshakeTimeout) } @@ -99,12 +106,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -119,13 +128,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - requester: ExplodingRequester(), - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .httpProxyHandshakeTimeout) } @@ -136,12 +146,14 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } var server: Channel? - XCTAssertNoThrow(server = try ServerBootstrap(group: group) - .childChannelInitializer { channel in - channel.pipeline.addHandler(NeverrespondServerHandler()) - } - .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) - .wait()) + XCTAssertNoThrow( + server = try ServerBootstrap(group: group) + .childChannelInitializer { channel in + channel.pipeline.addHandler(NeverrespondServerHandler()) + } + .bind(to: .init(ipAddress: "127.0.0.1", port: 0)) + .wait() + ) defer { XCTAssertNoThrow(try server?.close().wait()) } @@ -158,20 +170,21 @@ class HTTPConnectionPool_FactoryTests: XCTestCase { sslContextCache: .init() ) - XCTAssertThrowsError(try factory.makeChannel( - requester: ExplodingRequester(), - connectionID: 1, - deadline: .now() + .seconds(1), - eventLoop: group.next(), - logger: .init(label: "test") - ).wait() + XCTAssertThrowsError( + try factory.makeChannel( + requester: ExplodingRequester(), + connectionID: 1, + deadline: .now() + .seconds(1), + eventLoop: group.next(), + logger: .init(label: "test") + ).wait() ) { XCTAssertEqual($0 as? HTTPClientError, .tlsHandshakeTimeout) } } } -class NeverrespondServerHandler: ChannelInboundHandler { +final class NeverrespondServerHandler: ChannelInboundHandler, Sendable { typealias InboundIn = NIOAny func channelRead(context: ChannelHandlerContext, data: NIOAny) { @@ -181,11 +194,11 @@ class NeverrespondServerHandler: ChannelInboundHandler { /// A `HTTPConnectionRequester` that will fail a test if any of its methods are ever called. final class ExplodingRequester: HTTPConnectionRequester { - func http1ConnectionCreated(_: HTTP1Connection) { + func http1ConnectionCreated(_: HTTP1Connection.SendableView) { XCTFail("http1ConnectionCreated called unexpectedly") } - func http2ConnectionCreated(_: HTTP2Connection, maximumStreams: Int) { + func http2ConnectionCreated(_: HTTP2Connection.SendableView, maximumStreams: Int) { XCTFail("http2ConnectionCreated called unexpectedly") } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift index f1a641216..914990048 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1ConnectionsTest.swift @@ -12,15 +12,20 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCreatingConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -52,7 +57,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCreatingConnectionAndFailing() { let elg = EmbeddedEventLoopGroup(loops: 4) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -103,7 +112,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el3 = elg.next() let el4 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -130,7 +143,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -157,7 +174,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el4 = elg.next() let el5 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { XCTAssertEqual(connections.startingGeneralPurposeConnections, 0) @@ -181,7 +202,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el1 = elg.next() let el2 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el1, el1, el1, el2] { let connID = connections.createNewConnection(on: el) @@ -228,7 +253,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdle() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() @@ -248,7 +277,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdleButLeasedRaceCondition() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() @@ -267,7 +300,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testCloseConnectionIfIdleButClosedRaceCondition() { let elg = EmbeddedEventLoopGroup(loops: 1) - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) let el1 = elg.next() @@ -288,7 +325,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el3 = elg.next() let el4 = elg.next() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: .init(), maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: .init(), + maximumConnectionUses: nil + ) for el in [el1, el2, el3, el4] { let connID = connections.createNewConnection(on: el) @@ -343,7 +384,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator, maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -372,7 +417,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator, maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -408,10 +457,46 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { XCTAssertTrue(context.eventLoop === el3) } + func testMigrationFromHTTP2WithPendingRequestsWithRequiredEventLoopSameAsStartingConnections() { + let elg = EmbeddedEventLoopGroup(loops: 4) + let generator = HTTPConnectionPool.Connection.ID.Generator() + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) + + let el1 = elg.next() + let el2 = elg.next() + + let conn1ID = generator.next() + let conn2ID = generator.next() + + connections.migrateFromHTTP2( + starting: [(conn1ID, el1)], + backingOff: [(conn2ID, el2)] + ) + + let stats = connections.stats + XCTAssertEqual(stats.idle, 0) + XCTAssertEqual(stats.leased, 0) + XCTAssertEqual(stats.connecting, 1) + XCTAssertEqual(stats.backingOff, 1) + + let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) + let (_, context) = connections.newHTTP1ConnectionEstablished(conn1) + XCTAssertEqual(context.use, .generalPurpose) + XCTAssertTrue(context.eventLoop === el1) + } + func testMigrationFromHTTP2WithPendingRequestsWithPreferredEventLoop() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator, maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -450,7 +535,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithAlreadyLeasedHTTP1Connection() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator, maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() let el3 = elg.next() @@ -494,7 +583,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2WithMoreStartingConnectionsThanMaximumAllowedConccurentConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 2, generator: generator, maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 2, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -529,7 +622,11 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { func testMigrationFromHTTP2StartsEnoghOverflowConnectionsForRequiredEventLoopRequests() { let elg = EmbeddedEventLoopGroup(loops: 4) let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 1, generator: generator, maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 1, + generator: generator, + maximumConnectionUses: nil + ) let el1 = elg.next() let el2 = elg.next() @@ -571,16 +668,23 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { let el2 = elg.next() let generator = HTTPConnectionPool.Connection.ID.Generator() - var connections = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: generator, maximumConnectionUses: nil) + var connections = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: generator, + maximumConnectionUses: nil + ) let connID1 = connections.createNewConnection(on: el1) let context = connections.migrateToHTTP2() - XCTAssertEqual(context, .init( - backingOff: [], - starting: [(connID1, el1)], - close: [] - )) + XCTAssertEqual( + context, + .init( + backingOff: [], + starting: [(connID1, el1)], + close: [] + ) + ) let connID2 = generator.next() @@ -598,8 +702,7 @@ class HTTPConnectionPool_HTTP1ConnectionsTests: XCTestCase { extension HTTPConnectionPool.HTTP1Connections.HTTP1ToHTTP2MigrationContext: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { - return lhs.close == rhs.close && - lhs.starting.elementsEqual(rhs.starting, by: { $0.0 == $1.0 && $0.1 === $1.1 }) && - lhs.backingOff.elementsEqual(rhs.backingOff, by: { $0.0 == $1.0 && $0.1 === $1.1 }) + lhs.close == rhs.close && lhs.starting.elementsEqual(rhs.starting, by: { $0.0 == $1.0 && $0.1 === $1.1 }) + && lhs.backingOff.elementsEqual(rhs.backingOff, by: { $0.0 == $1.0 && $0.1 === $1.1 }) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift index 2df63a0f3..2be6cfa26 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP1StateTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { func testCreatingAndFailingConnections() { struct SomeError: Error, Equatable {} @@ -29,6 +30,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -113,6 +115,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: false, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -181,6 +184,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 2, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -194,9 +198,12 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux - let failedConnect1 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: connectionID) + let failedConnect1 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: connectionID + ) XCTAssertEqual(failedConnect1.request, .none) guard case .scheduleBackoffTimer(connectionID, let backoffTimeAmount1, _) = failedConnect1.connection else { return XCTFail("Unexpected connection action: \(failedConnect1.connection)") @@ -209,9 +216,12 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { return XCTFail("Unexpected connection action: \(backoffDoneAction.connection)") } XCTAssertGreaterThan(newConnectionID, connectionID) - XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux - let failedConnect2 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: newConnectionID) + let failedConnect2 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: newConnectionID + ) XCTAssertEqual(failedConnect2.request, .none) guard case .scheduleBackoffTimer(newConnectionID, let backoffTimeAmount2, _) = failedConnect2.connection else { return XCTFail("Unexpected connection action: \(failedConnect2.connection)") @@ -224,7 +234,9 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .failRequest(let requestToFail, let requestError, cancelTimeout: false) = failRequest.request else { return XCTFail("Unexpected request action: \(action.request)") } - XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + + // XCTAssertIdentical not available on Linux + XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) XCTAssertEqual(requestError as? HTTPClientError, .connectTimeout) XCTAssertEqual(failRequest.connection, .none) @@ -240,6 +252,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 2, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -253,7 +266,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. cancel request @@ -265,7 +278,9 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(state.timeoutRequest(request.id), .none, "To late timeout is ignored") // 4. succeed connection attempt - let connectedAction = state.newHTTP1ConnectionCreated(.__testOnly_connection(id: connectionID, eventLoop: connectionEL)) + let connectedAction = state.newHTTP1ConnectionCreated( + .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + ) XCTAssertEqual(connectedAction.request, .none, "Request must not be executed") XCTAssertEqual(connectedAction.connection, .scheduleTimeoutTimer(connectionID, on: connectionEL)) } @@ -278,6 +293,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 2, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -291,15 +307,18 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. connection succeeds - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connectionEL + ) let connectedAction = state.newHTTP1ConnectionCreated(connection) guard case .executeRequest(request, connection, cancelTimeout: true) = connectedAction.request else { return XCTFail("Unexpected request action: \(connectedAction.request)") } - XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux XCTAssertEqual(connectedAction.connection, .none) // 3. shutdown @@ -319,7 +338,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { let finalRequest = HTTPConnectionPool.Request(finalMockRequest) let failAction = state.executeRequest(finalRequest) XCTAssertEqual(failAction.connection, .none) - XCTAssertEqual(failAction.request, .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false)) + XCTAssertEqual( + failAction.request, + .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false) + ) // 5. close open connection let closeAction = state.http1ConnectionClosed(connectionID) @@ -340,7 +362,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // Add eight requests to fill all connections for _ in 0..<8 { let eventLoop = elg.next() - guard let expectedConnection = connections.newestParkedConnection(for: eventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: eventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to still have connections available") } @@ -349,7 +374,8 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { let action = state.executeRequest(request) XCTAssertEqual(action.connection, .cancelTimeoutTimer(expectedConnection.id)) - guard case .executeRequest(let returnedRequest, expectedConnection, cancelTimeout: false) = action.request else { + guard case .executeRequest(let returnedRequest, expectedConnection, cancelTimeout: false) = action.request + else { return XCTFail("Expected to execute a request next, but got: \(action.request)") } @@ -423,7 +449,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // 10% of the cases enforce the eventLoop let elRequired = (0..<10).randomElement().flatMap { $0 == 0 ? true : false }! - let mockRequest = MockHTTPScheduableRequest(eventLoop: reqEventLoop, requiresEventLoopForChannel: elRequired) + let mockRequest = MockHTTPScheduableRequest( + eventLoop: reqEventLoop, + requiresEventLoopForChannel: elRequired + ) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -435,7 +464,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(connEventLoop === reqEventLoop) XCTAssertEqual(action.request, .scheduleRequestTimeout(for: request, on: reqEventLoop)) - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connEventLoop) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connEventLoop + ) let createdAction = state.newHTTP1ConnectionCreated(connection) XCTAssertEqual(createdAction.request, .executeRequest(request, connection, cancelTimeout: true)) XCTAssertEqual(createdAction.connection, .none) @@ -446,7 +478,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(state.http1ConnectionClosed(connectionID), .none) case .cancelTimeoutTimer(let connectionID): - guard let expectedConnection = connections.newestParkedConnection(for: reqEventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: reqEventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to have connections available") } @@ -454,7 +489,11 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssert(expectedConnection.eventLoop === reqEventLoop) } - XCTAssertEqual(connectionID, expectedConnection.id, "Request is scheduled on the connection we expected") + XCTAssertEqual( + connectionID, + expectedConnection.id, + "Request is scheduled on the connection we expected" + ) XCTAssertNoThrow(try connections.activateConnection(connectionID)) guard case .executeRequest(let request, let connection, cancelTimeout: false) = action.request else { @@ -464,8 +503,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.execute(request.__testOnly_wrapped_request(), on: connection)) XCTAssertNoThrow(try connections.finishExecution(connection.id)) - XCTAssertEqual(state.http1ConnectionReleased(connection.id), - .init(request: .none, connection: .scheduleTimeoutTimer(connection.id, on: connection.eventLoop))) + XCTAssertEqual( + state.http1ConnectionReleased(connection.id), + .init(request: .none, connection: .scheduleTimeoutTimer(connection.id, on: connection.eventLoop)) + ) XCTAssertNoThrow(try connections.parkConnection(connectionID)) default: @@ -537,7 +578,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { // Add eight requests to fill all connections for _ in 0..<8 { let eventLoop = elg.next() - guard let expectedConnection = connections.newestParkedConnection(for: eventLoop) ?? connections.newestParkedConnection else { + guard + let expectedConnection = connections.newestParkedConnection(for: eventLoop) + ?? connections.newestParkedConnection + else { return XCTFail("Expected to still have connections available") } @@ -584,12 +628,20 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { guard let newConnection = maybeNewConnection else { return XCTFail("Expected to get a new connection") } let afterRecreationAction = state.newHTTP1ConnectionCreated(newConnection) XCTAssertEqual(afterRecreationAction.connection, .none) - guard case .executeRequest(let request, newConnection, cancelTimeout: true) = afterRecreationAction.request else { + guard + case .executeRequest(let request, newConnection, cancelTimeout: true) = afterRecreationAction + .request + else { return XCTFail("Unexpected request action: \(action.request)") } XCTAssertEqual(request.id, queuedRequestsOrder.popFirst()) - XCTAssertNoThrow(try connections.execute(queuer.get(request.id, request: request.__testOnly_wrapped_request()), on: newConnection)) + XCTAssertNoThrow( + try connections.execute( + queuer.get(request.id, request: request.__testOnly_wrapped_request()), + on: newConnection + ) + ) case .none: XCTAssert(queuer.isEmpty) @@ -670,6 +722,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 6, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -710,6 +763,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 6, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -723,7 +777,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request, on: mockRequest.eventLoop)) - let failAction = state.failedToCreateNewConnection(HTTPClientError.httpProxyHandshakeTimeout, connectionID: connectionID) + let failAction = state.failedToCreateNewConnection( + HTTPClientError.httpProxyHandshakeTimeout, + connectionID: connectionID + ) guard case .scheduleBackoffTimer(connectionID, backoff: _, on: let timerEL) = failAction.connection else { return XCTFail("Expected to create a backoff timer") } @@ -731,7 +788,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(failAction.request, .none) let timeoutAction = state.timeoutRequest(request.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request, HTTPClientError.httpProxyHandshakeTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request, HTTPClientError.httpProxyHandshakeTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } @@ -743,6 +803,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 6, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -756,7 +817,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction.request, .scheduleRequestTimeout(for: request, on: mockRequest.eventLoop)) let timeoutAction = state.timeoutRequest(request.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request, HTTPClientError.connectTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request, HTTPClientError.connectTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } @@ -768,6 +832,7 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 6, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -793,7 +858,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(executeAction2.request, .scheduleRequestTimeout(for: request2, on: connEL1)) - let failAction = state.failedToCreateNewConnection(HTTPClientError.httpProxyHandshakeTimeout, connectionID: connectionID1) + let failAction = state.failedToCreateNewConnection( + HTTPClientError.httpProxyHandshakeTimeout, + connectionID: connectionID1 + ) guard case .scheduleBackoffTimer(connectionID1, backoff: _, on: let timerEL) = failAction.connection else { return XCTFail("Expected to create a backoff timer") } @@ -807,7 +875,10 @@ class HTTPConnectionPool_HTTP1StateMachineTests: XCTestCase { XCTAssertEqual(createdAction.connection, .none) let timeoutAction = state.timeoutRequest(request2.id) - XCTAssertEqual(timeoutAction.request, .failRequest(request2, HTTPClientError.getConnectionFromPoolTimeout, cancelTimeout: false)) + XCTAssertEqual( + timeoutAction.request, + .failRequest(request2, HTTPClientError.getConnectionFromPoolTimeout, cancelTimeout: false) + ) XCTAssertEqual(timeoutAction.connection, .none) } } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift index 69bf62d81..dd56a9102 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2ConnectionsTest.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { func testCreatingConnections() { let elg = EmbeddedEventLoopGroup(loops: 4) @@ -32,7 +33,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el1)) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) XCTAssertEqual(conn1CreatedContext.isIdle, true) XCTAssert(conn1CreatedContext.eventLoop === el1) @@ -46,7 +50,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn2ID = connections.createNewConnection(on: el2) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el2)) let conn2: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn2ID, eventLoop: el2) - let (conn2Index, conn2CreatedContext) = connections.newHTTP2ConnectionEstablished(conn2, maxConcurrentStreams: 100) + let (conn2Index, conn2CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn2, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) XCTAssertTrue(conn1CreatedContext.isIdle) XCTAssert(conn2CreatedContext.eventLoop === el2) @@ -83,7 +90,9 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssert(conn1FailContext.eventLoop === el1) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) XCTAssertFalse(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests(for: el1)) - let (replaceConn1ID, replaceConn1EL) = connections.createNewConnectionByReplacingClosedConnection(at: conn1FailIndex) + let (replaceConn1ID, replaceConn1EL) = connections.createNewConnectionByReplacingClosedConnection( + at: conn1FailIndex + ) XCTAssert(replaceConn1EL === el1) XCTAssertEqual(replaceConn1ID, 1) XCTAssertTrue(connections.hasConnectionThatCanOrWillBeAbleToExecuteRequests) @@ -336,13 +345,19 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 100) XCTAssertEqual(leasedConn1, conn1) XCTAssertEqual(leasdConnContext1.wasIdle, true) - XCTAssertNil(connections.leaseStream(onRequired: el1), "should not be able to lease stream because they are all already leased") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "should not be able to lease stream because they are all already leased" + ) let (_, releaseContext) = connections.releaseStream(conn1ID) XCTAssertFalse(releaseContext.isIdle) @@ -354,7 +369,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertEqual(leasedConn, conn1) XCTAssertEqual(leaseContext.wasIdle, false) - XCTAssertNil(connections.leaseStream(onRequired: el1), "should not be able to lease stream because they are all already leased") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "should not be able to lease stream because they are all already leased" + ) } func testGoAway() { @@ -364,7 +382,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 10) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 10 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 10) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -386,7 +407,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) ) - XCTAssertNil(connections.leaseStream(onRequired: el1), "we should not be able to lease a stream because the connection is draining") + XCTAssertNil( + connections.leaseStream(onRequired: el1), + "we should not be able to lease a stream because the connection is draining" + ) // a server can potentially send more than one connection go away and we should not crash XCTAssertTrue(connections.goAwayReceived(conn1ID)?.eventLoop === el1) @@ -445,7 +469,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) @@ -454,7 +481,8 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertNil(connections.leaseStream(onRequired: el1), "all streams are in use") - guard let (_, newSettingsContext1) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 2) else { + guard let (_, newSettingsContext1) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 2) + else { return XCTFail("Expected to get a new settings context") } XCTAssertEqual(newSettingsContext1.availableStreams, 1) @@ -467,7 +495,8 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { XCTAssertEqual(leasedConn2, conn1) XCTAssertEqual(leaseContext2.wasIdle, false) - guard let (_, newSettingsContext2) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 1) else { + guard let (_, newSettingsContext2) = connections.newHTTP2MaxConcurrentStreamsReceived(conn1ID, newMaxStreams: 1) + else { return XCTFail("Expected to get a new settings context") } XCTAssertEqual(newSettingsContext2.availableStreams, 0) @@ -500,7 +529,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) @@ -535,7 +567,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { let conn1ID = connections.createNewConnection(on: el1) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 1) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 1 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 1) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 1) XCTAssertEqual(leasedConn1, conn1) @@ -556,9 +591,11 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { starting: [(conn1ID, el1)], backingOff: [(conn2ID, el2)] ) - XCTAssertTrue(connections.createConnectionsAfterMigrationIfNeeded( - requiredEventLoopsOfPendingRequests: [el1, el2] - ).isEmpty) + XCTAssertTrue( + connections.createConnectionsAfterMigrationIfNeeded( + requiredEventLoopsOfPendingRequests: [el1, el2] + ).isEmpty + ) XCTAssertEqual( connections.stats, @@ -574,7 +611,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -615,7 +655,10 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { ) let conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: conn1ID, eventLoop: el1) - let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished(conn1, maxConcurrentStreams: 100) + let (conn1Index, conn1CreatedContext) = connections.newHTTP2ConnectionEstablished( + conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(conn1CreatedContext.availableStreams, 100) let (leasedConn1, leasdConnContext1) = connections.leaseStreams(at: conn1Index, count: 2) @@ -714,9 +757,12 @@ class HTTPConnectionPool_HTTP2ConnectionsTests: XCTestCase { backingOff: [(conn3ID, el3)] ) - XCTAssertTrue(connections.createConnectionsAfterMigrationIfNeeded( - requiredEventLoopsOfPendingRequests: [el1, el2, el3] - ).isEmpty, "we still have an active connection for el1 and should not create a new one") + XCTAssertTrue( + connections.createConnectionsAfterMigrationIfNeeded( + requiredEventLoopsOfPendingRequests: [el1, el2, el3] + ).isEmpty, + "we still have an active connection for el1 and should not create a new one" + ) guard let (leasedConn, _) = connections.leaseStream(onRequired: el1) else { return XCTFail("could not lease stream on el1") diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift index 30b49662a..e64fd5e71 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+HTTP2StateMachineTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + private typealias Action = HTTPConnectionPool.StateMachine.Action private typealias ConnectionAction = HTTPConnectionPool.StateMachine.ConnectionAction private typealias RequestAction = HTTPConnectionPool.StateMachine.RequestAction @@ -127,14 +128,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { /// shutdown should only close one connection let shutdownAction = state.shutdown() XCTAssertEqual(shutdownAction.request, .none) - XCTAssertEqual(shutdownAction.connection, .cleanupConnections( - .init( - close: [conn], - cancel: [], - connectBackoff: [] - ), - isShutdown: .yes(unclean: false) - )) + XCTAssertEqual( + shutdownAction.connection, + .cleanupConnections( + .init( + close: [conn], + cancel: [], + connectBackoff: [] + ), + isShutdown: .yes(unclean: false) + ) + ) } func testConnectionFailureBackoff() { @@ -158,9 +162,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux - let failedConnect1 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: connectionID) + let failedConnect1 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: connectionID + ) XCTAssertEqual(failedConnect1.request, .none) guard case .scheduleBackoffTimer(connectionID, let backoffTimeAmount1, _) = failedConnect1.connection else { return XCTFail("Unexpected connection action: \(failedConnect1.connection)") @@ -173,9 +180,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { return XCTFail("Unexpected connection action: \(backoffDoneAction.connection)") } XCTAssertGreaterThan(newConnectionID, connectionID) - XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === newEventLoop) // XCTAssertIdentical not available on Linux - let failedConnect2 = state.failedToCreateNewConnection(HTTPClientError.connectTimeout, connectionID: newConnectionID) + let failedConnect2 = state.failedToCreateNewConnection( + HTTPClientError.connectTimeout, + connectionID: newConnectionID + ) XCTAssertEqual(failedConnect2.request, .none) guard case .scheduleBackoffTimer(newConnectionID, let backoffTimeAmount2, _) = failedConnect2.connection else { return XCTFail("Unexpected connection action: \(failedConnect2.connection)") @@ -188,7 +198,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .failRequest(let requestToFail, let requestError, cancelTimeout: false) = failRequest.request else { return XCTFail("Unexpected request action: \(action.request)") } - XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + // XCTAssertIdentical not available on Linux + XCTAssert(requestToFail.__testOnly_wrapped_request() === mockRequest) XCTAssertEqual(requestError as? HTTPClientError, .connectTimeout) XCTAssertEqual(failRequest.connection, .none) @@ -218,7 +229,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. initialise shutdown let shutdownAction = state.shutdown() @@ -257,11 +268,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = action.connection else { return XCTFail("Unexpected connection action: \(action.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux let failedConnectAction = state.failedToCreateNewConnection(SomeError(), connectionID: connectionID) XCTAssertEqual(failedConnectAction.connection, .none) - guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = failedConnectAction.request else { + guard case .failRequestsAndCancelTimeouts(let requestsToFail, let requestError) = failedConnectAction.request + else { return XCTFail("Unexpected request action: \(action.request)") } XCTAssertEqualTypeAndValue(requestError, SomeError()) @@ -289,7 +301,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. cancel request let cancelAction = state.cancelRequest(request.id) @@ -329,15 +341,18 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .createConnection(let connectionID, on: let connectionEL) = executeAction.connection else { return XCTFail("Unexpected connection action: \(executeAction.connection)") } - XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux + XCTAssert(connectionEL === mockRequest.eventLoop) // XCTAssertIdentical not available on Linux // 2. connection succeeds - let connection: HTTPConnectionPool.Connection = .__testOnly_connection(id: connectionID, eventLoop: connectionEL) + let connection: HTTPConnectionPool.Connection = .__testOnly_connection( + id: connectionID, + eventLoop: connectionEL + ) let connectedAction = state.newHTTP2ConnectionEstablished(connection, maxConcurrentStreams: 100) guard case .executeRequestsAndCancelTimeouts([request], connection) = connectedAction.request else { return XCTFail("Unexpected request action: \(connectedAction.request)") } - XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux + XCTAssert(request.__testOnly_wrapped_request() === mockRequest) // XCTAssertIdentical not available on Linux XCTAssertEqual(connectedAction.connection, .none) // 3. shutdown @@ -357,7 +372,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let finalRequest = HTTPConnectionPool.Request(finalMockRequest) let failAction = state.executeRequest(finalRequest) XCTAssertEqual(failAction.connection, .none) - XCTAssertEqual(failAction.request, .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false)) + XCTAssertEqual( + failAction.request, + .failRequest(finalRequest, HTTPClientError.alreadyShutdown, cancelTimeout: false) + ) // 5. close open connection let closeAction = state.http2ConnectionClosed(connectionID) @@ -416,7 +434,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { newHTTP2Connection: conn2, maxConcurrentStreams: 100 ) - XCTAssertEqual(http2ConnectAction.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual( + http2ConnectAction.connection, + .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil) + ) guard case .executeRequestsAndCancelTimeouts([request2], conn2) = http2ConnectAction.request else { return XCTFail("Unexpected request action \(http2ConnectAction.request)") } @@ -428,11 +449,17 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { let shutdownAction = http2State.shutdown() XCTAssertEqual(shutdownAction.request, .none) - XCTAssertEqual(shutdownAction.connection, .cleanupConnections(.init( - close: [conn2], - cancel: [], - connectBackoff: [] - ), isShutdown: .no)) + XCTAssertEqual( + shutdownAction.connection, + .cleanupConnections( + .init( + close: [conn2], + cancel: [], + connectBackoff: [] + ), + isShutdown: .no + ) + ) let releaseAction = http2State.http1ConnectionReleased(conn1ID) XCTAssertEqual(releaseAction.request, .none) @@ -445,7 +472,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator, maximumConnectionUses: nil) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: idGenerator, @@ -455,14 +486,22 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request on idle connection let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) @@ -495,7 +534,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator, maximumConnectionUses: nil) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: idGenerator, @@ -505,13 +548,21 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // let the connection timeout let timeoutAction = state.connectionIdleTimeout(conn1ID) @@ -528,7 +579,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator, maximumConnectionUses: nil) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: idGenerator, @@ -537,13 +592,21 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maximumConnectionUses: nil ) let conn1 = HTTPConnectionPool.Connection.__testOnly_connection(id: conn1ID, eventLoop: el1) - let connectAction = state.migrateFromHTTP1(http1Connections: http1Conns, requests: .init(), newHTTP2Connection: conn1, maxConcurrentStreams: 100) + let connectAction = state.migrateFromHTTP1( + http1Connections: http1Conns, + requests: .init(), + newHTTP2Connection: conn1, + maxConcurrentStreams: 100 + ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // create new http2 connection let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el2, requiresEventLoopForChannel: true) @@ -568,7 +631,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator, maximumConnectionUses: nil) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: idGenerator, @@ -586,11 +653,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 100 ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) let goAwayAction = state.http2ConnectionGoAwayReceived(conn1ID) XCTAssertEqual(goAwayAction.request, .none) @@ -603,7 +673,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator, maximumConnectionUses: nil) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: idGenerator, @@ -620,11 +694,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 100 ) XCTAssertEqual(connectAction.request, .none) - XCTAssertEqual(connectAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request on idle connection let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) @@ -649,7 +726,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // establish one idle http2 connection let idGenerator = HTTPConnectionPool.Connection.ID.Generator() - var http1Conns = HTTPConnectionPool.HTTP1Connections(maximumConcurrentConnections: 8, generator: idGenerator, maximumConnectionUses: nil) + var http1Conns = HTTPConnectionPool.HTTP1Connections( + maximumConcurrentConnections: 8, + generator: idGenerator, + maximumConnectionUses: nil + ) let conn1ID = http1Conns.createNewConnection(on: el1) var state = HTTPConnectionPool.HTTP2StateMachine( idGenerator: idGenerator, @@ -666,11 +747,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { maxConcurrentStreams: 1 ) XCTAssertEqual(connectAction1.request, .none) - XCTAssertEqual(connectAction1.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: (conn1ID, el1) - )) + XCTAssertEqual( + connectAction1.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: (conn1ID, el1) + ) + ) // execute request let mockRequest1 = MockHTTPScheduableRequest(eventLoop: el1) @@ -720,6 +804,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -769,11 +854,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.execute(request.__testOnly_wrapped_request(), on: conn1)) } - XCTAssertEqual(migrationAction.connection, .migration( - createConnections: [], - closeConnections: [], - scheduleTimeout: nil - )) + XCTAssertEqual( + migrationAction.connection, + .migration( + createConnections: [], + closeConnections: [], + scheduleTimeout: nil + ) + ) /// remaining connections should be closed immediately without executing any request for connID in connectionIDs.dropFirst() { @@ -811,6 +899,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: false, maximumConnectionUses: nil ) @@ -858,6 +947,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) @@ -930,7 +1020,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .executeRequestsAndCancelTimeouts(let requests, let conn) = migrationAction.request else { return XCTFail("unexpected request action \(migrationAction.request)") } - XCTAssertEqual(migrationAction.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual( + migrationAction.connection, + .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil) + ) XCTAssertEqual(conn, http2Conn) XCTAssertEqual(requests.count, 10) @@ -998,6 +1091,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: false, maximumConnectionUses: nil ) @@ -1014,11 +1108,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request1.id)) let http2Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let executeAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = executeAction1.request else { + return XCTFail("unexpected request action \(executeAction1.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -1026,14 +1120,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } // a request with new required event loop should create a new connection - let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest( + eventLoop: el2, + requiresEventLoopForChannel: true + ) let requestWithRequiredEventLoop = HTTPConnectionPool.Request(mockRequestWithRequiredEventLoop) let action2 = state.executeRequest(requestWithRequiredEventLoop) guard case .createConnection(let http1ConnId, let http1EventLoop) = action2.connection else { return XCTFail("Unexpected connection action \(action2.connection)") } XCTAssertTrue(http1EventLoop === el2) - XCTAssertEqual(action2.request, .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop)) + XCTAssertEqual( + action2.request, + .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop) + ) XCTAssertNoThrow(try connections.createConnection(http1ConnId, on: el2)) XCTAssertNoThrow(try queuer.queue(mockRequestWithRequiredEventLoop, id: requestWithRequiredEventLoop.id)) @@ -1044,7 +1144,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { guard case .executeRequest(let request2, http1Conn, cancelTimeout: true) = migrationAction2.request else { return XCTFail("unexpected request action \(migrationAction2.request)") } - guard case .migration(let createConnections, closeConnections: [], scheduleTimeout: nil) = migrationAction2.connection else { + guard + case .migration(let createConnections, closeConnections: [], scheduleTimeout: nil) = migrationAction2 + .connection + else { return XCTFail("unexpected connection action \(migrationAction2.connection)") } XCTAssertEqual(createConnections.map { $0.1.id }, [el2.id]) @@ -1069,6 +1172,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: false, maximumConnectionUses: nil ) @@ -1085,11 +1189,11 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request1.id)) let http2Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let executeAction1 = state.newHTTP2ConnectionCreated(http2Conn, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn) = executeAction1.request else { + return XCTFail("unexpected request action \(executeAction1.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) + XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -1097,14 +1201,20 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } // a request with new required event loop should create a new connection - let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest(eventLoop: el2, requiresEventLoopForChannel: true) + let mockRequestWithRequiredEventLoop = MockHTTPScheduableRequest( + eventLoop: el2, + requiresEventLoopForChannel: true + ) let requestWithRequiredEventLoop = HTTPConnectionPool.Request(mockRequestWithRequiredEventLoop) let action2 = state.executeRequest(requestWithRequiredEventLoop) guard case .createConnection(let http1ConnId, let http1EventLoop) = action2.connection else { return XCTFail("Unexpected connection action \(action2.connection)") } XCTAssertTrue(http1EventLoop === el2) - XCTAssertEqual(action2.request, .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop)) + XCTAssertEqual( + action2.request, + .scheduleRequestTimeout(for: requestWithRequiredEventLoop, on: mockRequestWithRequiredEventLoop.eventLoop) + ) XCTAssertNoThrow(try connections.createConnection(http1ConnId, on: el2)) XCTAssertNoThrow(try queuer.queue(mockRequestWithRequiredEventLoop, id: requestWithRequiredEventLoop.id)) @@ -1120,13 +1230,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } XCTAssertTrue(queuer.isEmpty) - // if we established a new http/1 connection we should migrate back to http/1, + // if we established a new http/1 connection we should migrate to http/1, // close the connection and shutdown the pool let http1Conn: HTTPConnectionPool.Connection = .__testOnly_connection(id: http1ConnId, eventLoop: el2) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP1(http1ConnId)) let migrationAction2 = state.newHTTP1ConnectionCreated(http1Conn) XCTAssertEqual(migrationAction2.request, .none) - XCTAssertEqual(migrationAction2.connection, .migration(createConnections: [], closeConnections: [http1Conn], scheduleTimeout: nil)) + XCTAssertEqual( + migrationAction2.connection, + .migration(createConnections: [], closeConnections: [http1Conn], scheduleTimeout: nil) + ) // in http/1 state, we should close idle http2 connections XCTAssertNoThrow(try connections.finishExecution(http2Conn.id)) @@ -1146,11 +1259,12 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: false, maximumConnectionUses: nil ) var connectionIDs: [HTTPConnectionPool.Connection.ID] = [] - for el in [el1, el2, el2] { + for el in [el1, el2] { let mockRequest = MockHTTPScheduableRequest(eventLoop: el, requiresEventLoopForChannel: true) let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) @@ -1164,7 +1278,7 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try queuer.queue(mockRequest, id: request.id)) } - // fail the two connections for el2 + // fail the connection for el2 for connectionID in connectionIDs.dropFirst() { struct SomeError: Error {} XCTAssertNoThrow(try connections.failConnectionCreation(connectionID)) @@ -1177,16 +1291,14 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } let http2ConnID1 = connectionIDs[0] let http2ConnID2 = connectionIDs[1] - let http2ConnID3 = connectionIDs[2] // let the first connection on el1 succeed as a http2 connection let http2Conn1: HTTPConnectionPool.Connection = .__testOnly_connection(id: http2ConnID1, eventLoop: el1) XCTAssertNoThrow(try connections.succeedConnectionCreationHTTP2(http2ConnID1, maxConcurrentStreams: 10)) - let migrationAction1 = state.newHTTP2ConnectionCreated(http2Conn1, maxConcurrentStreams: 10) - guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn1) = migrationAction1.request else { - return XCTFail("unexpected request action \(migrationAction1.request)") + let connectionAction = state.newHTTP2ConnectionCreated(http2Conn1, maxConcurrentStreams: 10) + guard case .executeRequestsAndCancelTimeouts(let requests, http2Conn1) = connectionAction.request else { + return XCTFail("unexpected request action \(connectionAction.request)") } - XCTAssertEqual(migrationAction1.connection, .migration(createConnections: [], closeConnections: [], scheduleTimeout: nil)) XCTAssertEqual(requests.count, 1) for request in requests { XCTAssertNoThrow(try queuer.get(request.id, request: request.__testOnly_wrapped_request())) @@ -1205,14 +1317,6 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { } XCTAssertTrue(eventLoop2 === el2) XCTAssertNoThrow(try connections.createConnection(newHttp2ConnID2, on: el2)) - - // we now have a starting connection for el2 and another one backing off - - // if the backoff timer fires now for a connection on el2, we should *not* start a new connection - XCTAssertNoThrow(try connections.connectionBackoffTimerDone(http2ConnID3)) - let action3 = state.connectionCreationBackoffDone(http2ConnID3) - XCTAssertEqual(action3.request, .none) - XCTAssertEqual(action3.connection, .none) } func testMaxConcurrentStreamsIsRespected() { @@ -1238,10 +1342,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { case 0: XCTAssertEqual(executeAction.connection, .cancelTimeoutTimer(generalPurposeConnection.id)) XCTAssertNoThrow(try connections.activateConnection(generalPurposeConnection.id)) - XCTAssertEqual(executeAction.request, .executeRequest(request, generalPurposeConnection, cancelTimeout: false)) + XCTAssertEqual( + executeAction.request, + .executeRequest(request, generalPurposeConnection, cancelTimeout: false) + ) XCTAssertNoThrow(try connections.execute(mockRequest, on: generalPurposeConnection)) case 1..<100: - XCTAssertEqual(executeAction.request, .executeRequest(request, generalPurposeConnection, cancelTimeout: false)) + XCTAssertEqual( + executeAction.request, + .executeRequest(request, generalPurposeConnection, cancelTimeout: false) + ) XCTAssertEqual(executeAction.connection, .none) XCTAssertNoThrow(try connections.execute(mockRequest, on: generalPurposeConnection)) case 100..<1000: @@ -1259,7 +1369,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1274,11 +1385,23 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // Next the server allows for more concurrent streams let newMaxStreams = 200 - XCTAssertNoThrow(try connections.newHTTP2ConnectionSettingsReceived(generalPurposeConnection.id, maxConcurrentStreams: newMaxStreams)) - let newMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived(generalPurposeConnection.id, newMaxStreams: newMaxStreams) + XCTAssertNoThrow( + try connections.newHTTP2ConnectionSettingsReceived( + generalPurposeConnection.id, + maxConcurrentStreams: newMaxStreams + ) + ) + let newMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived( + generalPurposeConnection.id, + newMaxStreams: newMaxStreams + ) XCTAssertEqual(newMaxStreamsAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = newMaxStreamsAction.request else { - return XCTFail("Unexpected request action after new max concurrent stream setting: \(newMaxStreamsAction.request)") + guard + case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = newMaxStreamsAction.request + else { + return XCTFail( + "Unexpected request action after new max concurrent stream setting: \(newMaxStreamsAction.request)" + ) } XCTAssertEqual(requests.count, 100, "Expected to execute 100 more requests") for request in requests { @@ -1295,7 +1418,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1308,8 +1432,16 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { // Next the server allows for fewer concurrent streams let fewerMaxStreams = 50 - XCTAssertNoThrow(try connections.newHTTP2ConnectionSettingsReceived(generalPurposeConnection.id, maxConcurrentStreams: fewerMaxStreams)) - let fewerMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived(generalPurposeConnection.id, newMaxStreams: fewerMaxStreams) + XCTAssertNoThrow( + try connections.newHTTP2ConnectionSettingsReceived( + generalPurposeConnection.id, + maxConcurrentStreams: fewerMaxStreams + ) + ) + let fewerMaxStreamsAction = state.newHTTP2MaxConcurrentStreamsReceived( + generalPurposeConnection.id, + newMaxStreams: fewerMaxStreams + ) XCTAssertEqual(fewerMaxStreamsAction.connection, .none) XCTAssertEqual(fewerMaxStreamsAction.request, .none) @@ -1327,7 +1459,8 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { XCTAssertNoThrow(try connections.finishExecution(generalPurposeConnection.id)) let finishAction = state.http2ConnectionStreamClosed(generalPurposeConnection.id) XCTAssertEqual(finishAction.connection, .none) - guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request else { + guard case .executeRequestsAndCancelTimeouts(let requests, generalPurposeConnection) = finishAction.request + else { return XCTFail("Unexpected request action: \(finishAction.request)") } guard requests.count == 1, let request = requests.first else { @@ -1347,7 +1480,10 @@ class HTTPConnectionPool_HTTP2StateMachineTests: XCTestCase { switch remaining { case 1: timeoutTimerScheduled = true - XCTAssertEqual(finishAction.connection, .scheduleTimeoutTimer(generalPurposeConnection.id, on: generalPurposeConnection.eventLoop)) + XCTAssertEqual( + finishAction.connection, + .scheduleTimeoutTimer(generalPurposeConnection.id, on: generalPurposeConnection.eventLoop) + ) XCTAssertNoThrow(try connections.parkConnection(generalPurposeConnection.id)) case 2...50: XCTAssertEqual(finishAction.connection, .none) @@ -1392,13 +1528,17 @@ func XCTAssertEqualTypeAndValue( file: StaticString = #filePath, line: UInt = #line ) { - XCTAssertNoThrow(try { - let lhs = try lhs() - let rhs = try rhs() - guard let lhsAsRhs = lhs as? Right else { - XCTFail("could not cast \(lhs) of type \(type(of: lhs)) to \(type(of: rhs))", file: file, line: line) - return - } - XCTAssertEqual(lhsAsRhs, rhs, file: file, line: line) - }(), file: file, line: line) + XCTAssertNoThrow( + try { + let lhs = try lhs() + let rhs = try rhs() + guard let lhsAsRhs = lhs as? Right else { + XCTFail("could not cast \(lhs) of type \(type(of: lhs)) to \(type(of: rhs))", file: file, line: line) + return + } + XCTAssertEqual(lhsAsRhs, rhs, file: file, line: line) + }(), + file: file, + line: line + ) } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift index d84e7f442..724c00b1f 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+ManagerTests.swift @@ -12,12 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Logging import NIOCore import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_ManagerTests: XCTestCase { func testManagerHappyPath() { let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 4) @@ -49,15 +51,17 @@ class HTTPConnectionPool_ManagerTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -105,15 +109,17 @@ class HTTPConnectionPool_ManagerTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift index f8d6044cd..4f4bbd785 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+RequestQueueTests.swift @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOEmbedded @@ -20,6 +19,8 @@ import NIOHTTP1 import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPool_RequestQueueTests: XCTestCase { func testCountAndIsEmptyWorks() { var queue = HTTPConnectionPool.RequestQueue() @@ -82,7 +83,7 @@ class HTTPConnectionPool_RequestQueueTests: XCTestCase { } } -private class MockScheduledRequest: HTTPSchedulableRequest { +final private class MockScheduledRequest: HTTPSchedulableRequest { let requiredEventLoop: EventLoop? init(requiredEventLoop: EventLoop?) { diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift index 53bba940c..bd9752d5d 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPool+StateTestUtils.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Atomics import Dispatch import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +@testable import AsyncHTTPClient + /// An `EventLoopGroup` of `EmbeddedEventLoop`s. final class EmbeddedEventLoopGroup: EventLoopGroup { private let loops: [EmbeddedEventLoop] @@ -34,7 +35,7 @@ final class EmbeddedEventLoopGroup: EventLoopGroup { } internal func makeIterator() -> EventLoopIterator { - return EventLoopIterator(self.loops) + EventLoopIterator(self.loops) } internal func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) { @@ -56,7 +57,7 @@ final class EmbeddedEventLoopGroup: EventLoopGroup { extension HTTPConnectionPool.Request: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { - return lhs.id == rhs.id + lhs.id == rhs.id } } @@ -78,15 +79,24 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { switch (lhs, rhs) { case (.createConnection(let lhsConnID, on: let lhsEL), .createConnection(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL - case (.scheduleBackoffTimer(let lhsConnID, let lhsBackoff, on: let lhsEL), .scheduleBackoffTimer(let rhsConnID, let rhsBackoff, on: let rhsEL)): + case ( + .scheduleBackoffTimer(let lhsConnID, let lhsBackoff, on: let lhsEL), + .scheduleBackoffTimer(let rhsConnID, let rhsBackoff, on: let rhsEL) + ): return lhsConnID == rhsConnID && lhsBackoff == rhsBackoff && lhsEL === rhsEL case (.scheduleTimeoutTimer(let lhsConnID, on: let lhsEL), .scheduleTimeoutTimer(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL case (.cancelTimeoutTimer(let lhsConnID), .cancelTimeoutTimer(let rhsConnID)): return lhsConnID == rhsConnID - case (.closeConnection(let lhsConn, isShutdown: let lhsShut), .closeConnection(let rhsConn, isShutdown: let rhsShut)): + case ( + .closeConnection(let lhsConn, isShutdown: let lhsShut), + .closeConnection(let rhsConn, isShutdown: let rhsShut) + ): return lhsConn == rhsConn && lhsShut == rhsShut - case (.cleanupConnections(let lhsContext, isShutdown: let lhsShut), .cleanupConnections(let rhsContext, isShutdown: let rhsShut)): + case ( + .cleanupConnections(let lhsContext, isShutdown: let lhsShut), + .cleanupConnections(let rhsContext, isShutdown: let rhsShut) + ): return lhsContext == rhsContext && lhsShut == rhsShut case ( .migration( @@ -100,12 +110,13 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { let rhsScheduleTimeout ) ): - return lhsCreateConnections.elementsEqual(rhsCreateConnections, by: { - $0.0 == $1.0 && $0.1 === $1.1 - }) && - lhsCloseConnections == rhsCloseConnections && - lhsScheduleTimeout?.0 == rhsScheduleTimeout?.0 && - lhsScheduleTimeout?.1 === rhsScheduleTimeout?.1 + return lhsCreateConnections.elementsEqual( + rhsCreateConnections, + by: { + $0.0 == $1.0 && $0.1 === $1.1 + } + ) && lhsCloseConnections == rhsCloseConnections && lhsScheduleTimeout?.0 == rhsScheduleTimeout?.0 + && lhsScheduleTimeout?.1 === rhsScheduleTimeout?.1 case (.none, .none): return true default: @@ -117,15 +128,27 @@ extension HTTPConnectionPool.StateMachine.ConnectionAction: Equatable { extension HTTPConnectionPool.StateMachine.RequestAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { - case (.executeRequest(let lhsReq, let lhsConn, let lhsReqID), .executeRequest(let rhsReq, let rhsConn, let rhsReqID)): + case ( + .executeRequest(let lhsReq, let lhsConn, let lhsReqID), + .executeRequest(let rhsReq, let rhsConn, let rhsReqID) + ): return lhsReq == rhsReq && lhsConn == rhsConn && lhsReqID == rhsReqID - case (.executeRequestsAndCancelTimeouts(let lhsReqs, let lhsConn), .executeRequestsAndCancelTimeouts(let rhsReqs, let rhsConn)): + case ( + .executeRequestsAndCancelTimeouts(let lhsReqs, let lhsConn), + .executeRequestsAndCancelTimeouts(let rhsReqs, let rhsConn) + ): return lhsReqs.elementsEqual(rhsReqs, by: { $0 == $1 }) && lhsConn == rhsConn - case (.failRequest(let lhsReq, _, cancelTimeout: let lhsReqID), .failRequest(let rhsReq, _, cancelTimeout: let rhsReqID)): + case ( + .failRequest(let lhsReq, _, cancelTimeout: let lhsReqID), + .failRequest(let rhsReq, _, cancelTimeout: let rhsReqID) + ): return lhsReq == rhsReq && lhsReqID == rhsReqID case (.failRequestsAndCancelTimeouts(let lhsReqs, _), .failRequestsAndCancelTimeouts(let rhsReqs, _)): return lhsReqs.elementsEqual(rhsReqs, by: { $0 == $1 }) - case (.scheduleRequestTimeout(for: let lhsReq, on: let lhsEL), .scheduleRequestTimeout(for: let rhsReq, on: let rhsEL)): + case ( + .scheduleRequestTimeout(for: let lhsReq, on: let lhsEL), + .scheduleRequestTimeout(for: let rhsReq, on: let rhsEL) + ): return lhsReq == rhsReq && lhsEL === rhsEL case (.none, .none): return true @@ -146,7 +169,10 @@ extension HTTPConnectionPool.HTTP2StateMachine.EstablishedConnectionAction: Equa switch (lhs, rhs) { case (.scheduleTimeoutTimer(let lhsConnID, on: let lhsEL), .scheduleTimeoutTimer(let rhsConnID, on: let rhsEL)): return lhsConnID == rhsConnID && lhsEL === rhsEL - case (.closeConnection(let lhsConn, isShutdown: let lhsShut), .closeConnection(let rhsConn, isShutdown: let rhsShut)): + case ( + .closeConnection(let lhsConn, isShutdown: let lhsShut), + .closeConnection(let rhsConn, isShutdown: let rhsShut) + ): return lhsConn == rhsConn && lhsShut == rhsShut case (.none, .none): return true diff --git a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift index 2cf222afe..a40703456 100644 --- a/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPConnectionPoolTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + class HTTPConnectionPoolTests: XCTestCase { func testOnlyOneConnectionIsUsedForSubSequentRequests() { let httpBin = HTTPBin() @@ -53,15 +54,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -82,7 +85,6 @@ class HTTPConnectionPoolTests: XCTestCase { let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)") let poolDelegate = TestDelegate(eventLoop: eventLoop) - let pool = HTTPConnectionPool( eventLoopGroup: eventLoopGroup, sslContextCache: .init(), @@ -93,6 +95,74 @@ class HTTPConnectionPoolTests: XCTestCase { idGenerator: .init(), backgroundActivityLogger: .init(label: "test") ) + defer { + pool.shutdown() + XCTAssertNoThrow(try poolDelegate.future.wait()) + XCTAssertNoThrow(try eventLoop.scheduleTask(in: .milliseconds(100)) {}.futureResult.wait()) + XCTAssertEqual(httpBin.activeConnections, 0) + // Since we would migrate from h2 -> h1, which creates a general purpose connection + // for every connection in .starting state, after the first request which will + // be serviced by an overflow connection, the rest of requests will use the general + // purpose connection since they are all on the same event loop. + // Hence we will only create 1 overflow connection and 1 general purpose connection. + XCTAssertEqual(httpBin.createdConnections, 2) + } + + XCTAssertEqual(httpBin.createdConnections, 0) + + for _ in 0..<10 { + var maybeRequest: HTTPClient.Request? + var maybeRequestBag: RequestBag? + XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .init( + .testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next()) + ), + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) + + guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } + + pool.executeRequest(requestBag) + XCTAssertNoThrow(try requestBag.task.futureResult.wait()) + + // Flakiness Alert: We check <= and >= instead of == + // While migration from h2 -> h1, one general purpose and one over flow connection + // will be created, there's no guarantee as to whether the request is executed + // after both are created. + XCTAssertGreaterThanOrEqual(httpBin.createdConnections, 1) + XCTAssertLessThanOrEqual(httpBin.createdConnections, 2) + } + } + + func testConnectionsForEventLoopRequirementsAreClosedH1Only() { + let httpBin = HTTPBin() + defer { XCTAssertNoThrow(try httpBin.shutdown()) } + let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 2) + let eventLoop = eventLoopGroup.next() + defer { XCTAssertNoThrow(try eventLoopGroup.syncShutdownGracefully()) } + + let request = try! HTTPClient.Request(url: "http://localhost:\(httpBin.port)") + let poolDelegate = TestDelegate(eventLoop: eventLoop) + var configuration = HTTPClient.Configuration() + configuration.httpVersion = .http1Only + let pool = HTTPConnectionPool( + eventLoopGroup: eventLoopGroup, + sslContextCache: .init(), + tlsConfiguration: .none, + clientConfiguration: configuration, + key: .init(request), + delegate: poolDelegate, + idGenerator: .init(), + backgroundActivityLogger: .init(label: "test") + ) defer { pool.shutdown() XCTAssertNoThrow(try poolDelegate.future.wait()) @@ -107,15 +177,19 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .init(.testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next())), - task: .init(eventLoop: eventLoop, logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .init( + .testOnly_exact(channelOn: eventLoopGroup.next(), delegateOn: eventLoopGroup.next()) + ), + task: .init(eventLoop: eventLoop, logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -162,15 +236,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .distantFuture, - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .distantFuture, + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -216,15 +292,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -264,15 +342,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -320,15 +400,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -366,15 +448,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequest: HTTPClient.Request? var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/wait")) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: .init(label: "test")), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } @@ -426,15 +510,17 @@ class HTTPConnectionPoolTests: XCTestCase { var maybeRequestBag: RequestBag? XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: url)) - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: XCTUnwrap(maybeRequest), - eventLoopPreference: .indifferent, - task: .init(eventLoop: eventLoopGroup.next(), logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(5), - requestOptions: .forTests(), - delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: XCTUnwrap(maybeRequest), + eventLoopPreference: .indifferent, + task: .init(eventLoop: eventLoopGroup.next(), logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(5), + requestOptions: .forTests(), + delegate: ResponseAccumulator(request: XCTUnwrap(maybeRequest)) + ) + ) guard let requestBag = maybeRequestBag else { return XCTFail("Expected to get a request") } pool.executeRequest(requestBag) @@ -458,7 +544,10 @@ class HTTPConnectionPoolTests: XCTestCase { var backoff = HTTPConnectionPool.calculateBackoff(failedAttempt: 1) // The value should be 100ms±3ms - XCTAssertLessThanOrEqual((backoff - .milliseconds(100)).nanoseconds.magnitude, TimeAmount.milliseconds(3).nanoseconds.magnitude) + XCTAssertLessThanOrEqual( + (backoff - .milliseconds(100)).nanoseconds.magnitude, + TimeAmount.milliseconds(3).nanoseconds.magnitude + ) // Should always increase // We stop when we get within the jitter of 60s, which is 1.8s @@ -474,7 +563,8 @@ class HTTPConnectionPoolTests: XCTestCase { // Ok, now we should be able to do a hundred increments, and always hit 60s, plus or minus 1.8s of jitter. for offset in 0..<100 { XCTAssertLessThanOrEqual( - (HTTPConnectionPool.calculateBackoff(failedAttempt: attempt + offset) - .seconds(60)).nanoseconds.magnitude, + (HTTPConnectionPool.calculateBackoff(failedAttempt: attempt + offset) - .seconds(60)).nanoseconds + .magnitude, TimeAmount.milliseconds(1800).nanoseconds.magnitude ) } diff --git a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift index 92bf42b1d..8fe879745 100644 --- a/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/HTTPRequestStateMachineTests.swift @@ -12,22 +12,29 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import NIOSSL import XCTest +@testable import AsyncHTTPClient + class HTTPRequestStateMachineTests: XCTestCase { func testSimpleGETRequest() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -36,10 +43,21 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTRequestWithWriterBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [1])) let part2 = IOData.byteBuffer(ByteBuffer(bytes: [2])) @@ -62,7 +80,10 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamFinished(promise: nil), .sendRequestEnd(nil)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -71,14 +92,25 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooLong() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) let part1 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) - state.requestStreamPartReceived(part1, promise: nil).assertFailRequest(HTTPClientError.bodyLengthMismatch, .close(nil)) + state.requestStreamPartReceived(part1, promise: nil).assertFailRequest( + HTTPClientError.bodyLengthMismatch, + .close(nil) + ) // if another error happens the new one is ignored XCTAssertEqual(state.errorHappened(HTTPClientError.remoteConnectionClosed), .wait) @@ -86,9 +118,17 @@ class HTTPRequestStateMachineTests: XCTestCase { func testPOSTContentLengthIsTooShort() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "8")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "8")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(8)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) @@ -97,28 +137,51 @@ class HTTPRequestStateMachineTests: XCTestCase { func testRequestBodyStreamIsCancelledIfServerRespondsWith301() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: true)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: true) + ) XCTAssertEqual(state.writabilityChanged(writable: false), .wait) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), - "Expected to drop all stream data after having received a response head, with status >= 300") - - XCTAssertEqual(state.requestStreamFinished(promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + + XCTAssertEqual( + state.requestStreamFinished(promise: nil), + .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) } func testStreamPartReceived_whenCancelled() { @@ -126,47 +189,84 @@ class HTTPRequestStateMachineTests: XCTestCase { let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestCancelled(), .failRequest(HTTPClientError.cancelled, .none)) - XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.cancelled, nil), - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.cancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) } func testRequestBodyStreamIsCancelledIfServerRespondsWith301WhileWriteBackpressure() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) - XCTAssertEqual(state.headSent(), .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) + XCTAssertEqual( + state.headSent(), + .notifyRequestHeadSendSuccessfully(resumeRequestBodyStream: true, startIdleTimer: false) + ) let part = IOData.byteBuffer(ByteBuffer(bytes: [0, 1, 2, 3])) XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .sendBodyPart(part, nil)) XCTAssertEqual(state.writabilityChanged(writable: false), .pauseRequestBodyStream) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .movedPermanently) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.writabilityChanged(writable: true), .wait) - XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.close, .init())) - XCTAssertEqual(state.requestStreamPartReceived(part, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), - "Expected to drop all stream data after having received a response head, with status >= 300") - - XCTAssertEqual(state.requestStreamFinished(promise: nil), .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), - "Expected to drop all stream data after having received a response head, with status >= 300") + XCTAssertEqual( + state.requestStreamPartReceived(part, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) + + XCTAssertEqual( + state.requestStreamFinished(promise: nil), + .failSendStreamFinished(HTTPClientError.requestStreamCancelled, nil), + "Expected to drop all stream data after having received a response head, with status >= 300" + ) } func testRequestBodyStreamIsContinuedIfServerRespondsWith200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) @@ -175,20 +275,34 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .sendBodyPart(part2, nil)) XCTAssertEqual(state.requestStreamFinished(promise: nil), .succeedRequest(.sendRequestEnd(nil), .init())) - XCTAssertEqual(state.requestStreamPartReceived(part2, promise: nil), .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil)) + XCTAssertEqual( + state.requestStreamPartReceived(part2, promise: nil), + .failSendBodyPart(HTTPClientError.requestStreamCancelled, nil) + ) } func testRequestBodyStreamIsContinuedIfServerSendHeadWithStatus200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) @@ -201,15 +315,26 @@ class HTTPRequestStateMachineTests: XCTestCase { func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerRespondedWith200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .forwardResponseBodyParts(.init())) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) @@ -220,15 +345,26 @@ class HTTPRequestStateMachineTests: XCTestCase { func testRequestIsFailedIfRequestBodySizeIsWrongEvenAfterServerSendHeadWithStatus200() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/", headers: HTTPHeaders([("content-length", "12")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .POST, + uri: "/", + headers: HTTPHeaders([("content-length", "12")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(12)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part0 = IOData.byteBuffer(ByteBuffer(bytes: 0...3)) XCTAssertEqual(state.requestStreamPartReceived(part0, promise: nil), .sendBodyPart(part0, nil)) // response is coming before having send all data let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part1 = IOData.byteBuffer(ByteBuffer(bytes: 4...7)) XCTAssertEqual(state.requestStreamPartReceived(part1, promise: nil), .sendBodyPart(part1, nil)) @@ -245,7 +381,10 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.writabilityChanged(writable: true), .sendRequestHead(requestHead, sendEnd: true)) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -264,10 +403,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -291,10 +440,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -318,10 +477,20 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) let part1 = ByteBuffer(bytes: 4...7) let part2 = ByteBuffer(bytes: 8...11) @@ -336,7 +505,11 @@ class HTTPRequestStateMachineTests: XCTestCase { XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(part2)), .wait) XCTAssertEqual(state.read(), .read, "Calling `read` while we wait for a channelReadComplete doesn't crash") - XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait, "Calling `demandMoreResponseBodyParts` while we wait for a channelReadComplete doesn't crash") + XCTAssertEqual( + state.demandMoreResponseBodyParts(), + .wait, + "Calling `demandMoreResponseBodyParts` while we wait for a channelReadComplete doesn't crash" + ) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts(.init([part2]))) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) @@ -365,11 +538,17 @@ class HTTPRequestStateMachineTests: XCTestCase { // --- sending request let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) // --- receiving response let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "4"]) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let responseBody = ByteBuffer(bytes: [1, 2, 3, 4]) XCTAssertEqual(state.channelRead(.body(responseBody)), .wait) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init([responseBody]))) @@ -380,27 +559,51 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) } func testRemoteSuddenlyClosesTheConnection() { var state = HTTPRequestStateMachine(isChannelWritable: true) - let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/", headers: .init([("content-length", "4")])) + let requestHead = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: "/", + headers: .init([("content-length", "4")]) + ) let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(4)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) state.requestCancelled().assertFailRequest(HTTPClientError.cancelled, .close(nil)) - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3)), promise: nil), .failSendBodyPart(HTTPClientError.cancelled, nil)) + XCTAssertEqual( + state.requestStreamPartReceived(.byteBuffer(.init(bytes: 1...3)), promise: nil), + .failSendBodyPart(HTTPClientError.cancelled, nil) + ) } func testReadTimeoutLeadsToFailureWithEverythingAfterBeingIgnored() { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) - - let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: HTTPHeaders([("content-length", "12")])) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .ok, + headers: HTTPHeaders([("content-length", "12")]) + ) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) let part0 = ByteBuffer(bytes: 0...3) XCTAssertEqual(state.channelRead(.body(part0)), .wait) state.idleReadTimeoutTriggered().assertFailRequest(HTTPClientError.readTimeout, .close(nil)) @@ -414,13 +617,19 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let continueHead = HTTPResponseHead(version: .http1_1, status: .continue) XCTAssertEqual(state.channelRead(.head(continueHead)), .wait) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -430,10 +639,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.idleReadTimeoutTriggered(), .wait, "A read timeout that fires to late must be ignored") } @@ -442,10 +657,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.channelRead(.end(nil)), .succeedRequest(.none, .init())) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -454,9 +675,15 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) - - state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest(HTTPParserError.invalidChunkSize, .close(nil)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) + + state.errorHappened(HTTPParserError.invalidChunkSize).assertFailRequest( + HTTPParserError.invalidChunkSize, + .close(nil) + ) XCTAssertEqual(state.requestCancelled(), .wait, "A cancellation that happens to late is ignored") } @@ -464,10 +691,16 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -480,11 +713,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .internalServerError) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -498,13 +737,22 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .POST, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .stream) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: false)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: false) + ) let part1: ByteBuffer = .init(string: "foo") - XCTAssertEqual(state.requestStreamPartReceived(.byteBuffer(part1), promise: nil), .sendBodyPart(.byteBuffer(part1), nil)) + XCTAssertEqual( + state.requestStreamPartReceived(.byteBuffer(part1), promise: nil), + .sendBodyPart(.byteBuffer(part1), nil) + ) let responseHead = HTTPResponseHead(version: .http1_0, status: .ok) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -518,11 +766,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelRead(.body(body)), .wait) state.errorHappened(NIOSSLError.uncleanShutdown).assertFailRequest(NIOSSLError.uncleanShutdown, .close(nil)) @@ -534,7 +788,10 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) state.channelInactive().assertFailRequest(HTTPClientError.remoteConnectionClosed, .none) @@ -545,7 +802,10 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) state.errorHappened(ArbitraryError()).assertFailRequest(ArbitraryError(), .close(nil)) XCTAssertEqual(state.channelInactive(), .wait) @@ -555,17 +815,26 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["content-length": "30"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.read(), .read) XCTAssertEqual(state.channelRead(.body(body)), .wait) XCTAssertEqual(state.channelReadComplete(), .forwardResponseBodyParts([body])) XCTAssertEqual(state.errorHappened(NIOSSLError.uncleanShutdown), .wait) - state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest(HTTPParserError.invalidEOFState, .close(nil)) + state.errorHappened(HTTPParserError.invalidEOFState).assertFailRequest( + HTTPParserError.invalidEOFState, + .close(nil) + ) XCTAssertEqual(state.channelInactive(), .wait) } @@ -573,11 +842,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -594,11 +869,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -615,11 +896,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -635,11 +922,17 @@ class HTTPRequestStateMachineTests: XCTestCase { var state = HTTPRequestStateMachine(isChannelWritable: true) let requestHead = HTTPRequestHead(version: .http1_1, method: .GET, uri: "/") let metadata = RequestFramingMetadata(connectionClose: false, body: .fixedSize(0)) - XCTAssertEqual(state.startRequest(head: requestHead, metadata: metadata), .sendRequestHead(requestHead, sendEnd: true)) + XCTAssertEqual( + state.startRequest(head: requestHead, metadata: metadata), + .sendRequestHead(requestHead, sendEnd: true) + ) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["Content-Length": "50"]) let body = ByteBuffer(string: "foo bar") - XCTAssertEqual(state.channelRead(.head(responseHead)), .forwardResponseHead(responseHead, pauseRequestBodyStream: false)) + XCTAssertEqual( + state.channelRead(.head(responseHead)), + .forwardResponseHead(responseHead, pauseRequestBodyStream: false) + ) XCTAssertEqual(state.demandMoreResponseBodyParts(), .wait) XCTAssertEqual(state.channelReadComplete(), .wait) XCTAssertEqual(state.read(), .read) @@ -688,13 +981,19 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.resumeRequestBodyStream, .resumeRequestBodyStream): return true - case (.forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream)): + case ( + .forwardResponseHead(let lhsHead, let lhsPauseRequestBodyStream), + .forwardResponseHead(let rhsHead, let rhsPauseRequestBodyStream) + ): return lhsHead == rhsHead && lhsPauseRequestBodyStream == rhsPauseRequestBodyStream case (.forwardResponseBodyParts(let lhsData), .forwardResponseBodyParts(let rhsData)): return lhsData == rhsData - case (.succeedRequest(let lhsFinalAction, let lhsFinalBuffer), .succeedRequest(let rhsFinalAction, let rhsFinalBuffer)): + case ( + .succeedRequest(let lhsFinalAction, let lhsFinalBuffer), + .succeedRequest(let rhsFinalAction, let rhsFinalBuffer) + ): return lhsFinalAction == rhsFinalAction && lhsFinalBuffer == rhsFinalBuffer case (.failRequest(_, let lhsFinalAction), .failRequest(_, let rhsFinalAction)): @@ -706,10 +1005,16 @@ extension HTTPRequestStateMachine.Action: Equatable { case (.wait, .wait): return true - case (.failSendBodyPart(let lhsError as HTTPClientError, let lhsPromise), .failSendBodyPart(let rhsError as HTTPClientError, let rhsPromise)): + case ( + .failSendBodyPart(let lhsError as HTTPClientError, let lhsPromise), + .failSendBodyPart(let rhsError as HTTPClientError, let rhsPromise) + ): return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult - case (.failSendStreamFinished(let lhsError as HTTPClientError, let lhsPromise), .failSendStreamFinished(let rhsError as HTTPClientError, let rhsPromise)): + case ( + .failSendStreamFinished(let lhsError as HTTPClientError, let lhsPromise), + .failSendStreamFinished(let rhsError as HTTPClientError, let rhsPromise) + ): return lhsError == rhsError && lhsPromise?.futureResult == rhsPromise?.futureResult default: @@ -719,7 +1024,10 @@ extension HTTPRequestStateMachine.Action: Equatable { } extension HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction: Equatable { - public static func == (lhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, rhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction) -> Bool { + public static func == ( + lhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction, + rhs: HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction + ) -> Bool { switch (lhs, rhs) { case (.close, close): return true @@ -737,7 +1045,10 @@ extension HTTPRequestStateMachine.Action.FinalSuccessfulRequestAction: Equatable } extension HTTPRequestStateMachine.Action.FinalFailedRequestAction: Equatable { - public static func == (lhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction, rhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction) -> Bool { + public static func == ( + lhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction, + rhs: HTTPRequestStateMachine.Action.FinalFailedRequestAction + ) -> Bool { switch (lhs, rhs) { case (.close(let lhsPromise), close(let rhsPromise)): return lhsPromise?.futureResult == rhsPromise?.futureResult @@ -759,7 +1070,11 @@ extension HTTPRequestStateMachine.Action { line: UInt = #line ) where Error: Swift.Error & Equatable { guard case .failRequest(let actualError, let actualFinalStreamAction) = self else { - return XCTFail("expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", file: file, line: line) + return XCTFail( + "expected .failRequest(\(expectedError), \(expectedFinalStreamAction)) but got \(self)", + file: file, + line: line + ) } if let actualError = actualError as? Error { XCTAssertEqual(actualError, expectedError, file: file, line: line) diff --git a/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift b/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift index e7cfed4d0..e9a0d46dc 100644 --- a/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift +++ b/Tests/AsyncHTTPClientTests/IdleTimeoutNoReuseTests.swift @@ -14,9 +14,6 @@ import AsyncHTTPClient import Atomics -#if canImport(Network) -import Network -#endif import Logging import NIOConcurrencyHelpers import NIOCore @@ -29,6 +26,10 @@ import NIOTestUtils import NIOTransportServices import XCTest +#if canImport(Network) +import Network +#endif + final class TestIdleTimeoutNoReuse: XCTestCaseHTTPClientTestsBaseClass { func testIdleTimeoutNoReuse() throws { var req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", method: .GET) diff --git a/Tests/AsyncHTTPClientTests/LRUCacheTests.swift b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift index 6392bcebe..6173c34eb 100644 --- a/Tests/AsyncHTTPClientTests/LRUCacheTests.swift +++ b/Tests/AsyncHTTPClientTests/LRUCacheTests.swift @@ -12,9 +12,10 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import XCTest +@testable import AsyncHTTPClient + class LRUCacheTests: XCTestCase { func testBasicsWork() { var cache = LRUCache(capacity: 1) diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift index ca41a1e39..e49c67f19 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockConnectionPool.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 import NIOSSL +@testable import AsyncHTTPClient + /// A mock connection pool (not creating any actual connections) that is used to validate /// connection actions returned by the `HTTPConnectionPool.StateMachine`. struct MockConnectionPool { @@ -543,6 +544,7 @@ extension MockConnectionPool { idGenerator: .init(), maximumConcurrentHTTP1Connections: maxNumberOfConnections, retryConnectionEstablishment: true, + preferHTTP1: true, maximumConnectionUses: nil ) var connections = MockConnectionPool() @@ -553,7 +555,9 @@ extension MockConnectionPool { let request = HTTPConnectionPool.Request(mockRequest) let action = state.executeRequest(request) - guard case .scheduleRequestTimeout(request, on: let waitEL) = action.request, mockRequest.eventLoop === waitEL else { + guard case .scheduleRequestTimeout(request, on: let waitEL) = action.request, + mockRequest.eventLoop === waitEL + else { throw SetupError.expectedRequestToBeAddedToQueue } @@ -608,6 +612,7 @@ extension MockConnectionPool { idGenerator: .init(), maximumConcurrentHTTP1Connections: 8, retryConnectionEstablishment: true, + preferHTTP1: false, maximumConnectionUses: nil ) var connections = MockConnectionPool() @@ -619,7 +624,9 @@ extension MockConnectionPool { let request = HTTPConnectionPool.Request(mockRequest) let executeAction = state.executeRequest(request) - guard case .scheduleRequestTimeout(request, on: let waitEL) = executeAction.request, mockRequest.eventLoop === waitEL else { + guard case .scheduleRequestTimeout(request, on: let waitEL) = executeAction.request, + mockRequest.eventLoop === waitEL + else { throw SetupError.expectedRequestToBeAddedToQueue } @@ -632,17 +639,16 @@ extension MockConnectionPool { // 2. the connection becomes available - let newConnection = try connections.succeedConnectionCreationHTTP2(connectionID, maxConcurrentStreams: maxConcurrentStreams) + let newConnection = try connections.succeedConnectionCreationHTTP2( + connectionID, + maxConcurrentStreams: maxConcurrentStreams + ) let action = state.newHTTP2ConnectionCreated(newConnection, maxConcurrentStreams: maxConcurrentStreams) guard case .executeRequestsAndCancelTimeouts([request], newConnection) = action.request else { throw SetupError.expectedPreviouslyQueuedRequestToBeRunNow } - guard case .migration(createConnections: let create, closeConnections: [], scheduleTimeout: nil) = action.connection, create.isEmpty else { - throw SetupError.expectedNoConnectionAction - } - guard try queuer.get(request.id, request: request.__testOnly_wrapped_request()) === mockRequest else { throw SetupError.expectedPreviouslyQueuedRequestToBeRunNow } @@ -676,10 +682,12 @@ final class MockHTTPScheduableRequest: HTTPSchedulableRequest { let preferredEventLoop: EventLoop let requiredEventLoop: EventLoop? - init(eventLoop: EventLoop, - logger: Logger = Logger(label: "mock"), - connectionTimeout: TimeAmount = .seconds(60), - requiresEventLoopForChannel: Bool = false) { + init( + eventLoop: EventLoop, + logger: Logger = Logger(label: "mock"), + connectionTimeout: TimeAmount = .seconds(60), + requiresEventLoopForChannel: Bool = false + ) { self.logger = logger self.connectionDeadline = .now() + connectionTimeout @@ -694,7 +702,7 @@ final class MockHTTPScheduableRequest: HTTPSchedulableRequest { } var eventLoop: EventLoop { - return self.preferredEventLoop + self.preferredEventLoop } // MARK: HTTPSchedulableRequest diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift index aa0dc45eb..67f18cbb8 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockHTTPExecutableRequest.swift @@ -12,14 +12,16 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging +import NIOConcurrencyHelpers import NIOCore import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + final class MockHTTPExecutableRequest: HTTPExecutableRequest { - enum Event { + enum Event: Sendable { /// ``Event`` without associated values enum Kind: Hashable { case willExecuteRequest @@ -55,39 +57,49 @@ final class MockHTTPExecutableRequest: HTTPExecutableRequest { } } - var logger: Logging.Logger = Logger(label: "request") - var requestHead: NIOHTTP1.HTTPRequestHead - var requestFramingMetadata: RequestFramingMetadata - var requestOptions: RequestOptions = .forTests() + let logger: Logging.Logger = Logger(label: "request") + let requestHead: NIOHTTP1.HTTPRequestHead + let requestFramingMetadata: RequestFramingMetadata + let requestOptions: RequestOptions = .forTests() /// if true and ``HTTPExecutableRequest`` method is called without setting a corresponding callback on `self` e.g. /// If ``HTTPExecutableRequest\.willExecuteRequest(_:)`` is called but ``willExecuteRequestCallback`` is not set, /// ``XCTestFail(_:)`` will be called to fail the current test. - var raiseErrorIfUnimplementedMethodIsCalled: Bool = true - private var file: StaticString - private var line: UInt - - var willExecuteRequestCallback: ((HTTPRequestExecutor) -> Void)? - var requestHeadSentCallback: (() -> Void)? - var resumeRequestBodyStreamCallback: (() -> Void)? - var pauseRequestBodyStreamCallback: (() -> Void)? - var receiveResponseHeadCallback: ((HTTPResponseHead) -> Void)? - var receiveResponseBodyPartsCallback: ((CircularBuffer) -> Void)? - var succeedRequestCallback: ((CircularBuffer?) -> Void)? - var failCallback: ((Error) -> Void)? + let raiseErrorIfUnimplementedMethodIsCalled: Bool + private let file: StaticString + private let line: UInt + + let willExecuteRequestCallback: (@Sendable (HTTPRequestExecutor) -> Void)? = nil + let requestHeadSentCallback: (@Sendable () -> Void)? = nil + let resumeRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + let pauseRequestBodyStreamCallback: (@Sendable () -> Void)? = nil + let receiveResponseHeadCallback: (@Sendable (HTTPResponseHead) -> Void)? = nil + let receiveResponseBodyPartsCallback: (@Sendable (CircularBuffer) -> Void)? = nil + let succeedRequestCallback: (@Sendable (CircularBuffer?) -> Void)? = nil + let failCallback: (@Sendable (Error) -> Void)? = nil /// captures all ``HTTPExecutableRequest`` method calls in the order of occurrence, including arguments. /// If you are not interested in the arguments you can use `events.map(\.kind)` to get all events without arguments. - private(set) var events: [Event] = [] + private let _events = NIOLockedValueBox<[Event]>([]) + private(set) var events: [Event] { + get { + self._events.withLockedValue { $0 } + } + set { + self._events.withLockedValue { $0 = newValue } + } + } init( head: NIOHTTP1.HTTPRequestHead = .init(version: .http1_1, method: .GET, uri: "http://localhost/"), framingMetadata: RequestFramingMetadata = .init(connectionClose: false, body: .fixedSize(0)), + raiseErrorIfUnimplementedMethodIsCalled: Bool = true, file: StaticString = #file, line: UInt = #line ) { self.requestHead = head self.requestFramingMetadata = framingMetadata + self.raiseErrorIfUnimplementedMethodIsCalled = raiseErrorIfUnimplementedMethodIsCalled self.file = file self.line = line } diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift index b37ce8fa3..e5d9caa8e 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestExecutor.swift @@ -12,10 +12,11 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOConcurrencyHelpers import NIOCore +@testable import AsyncHTTPClient + // This is a MockRequestExecutor, that is synchronized on its EventLoop. final class MockRequestExecutor { enum Errors: Error { @@ -24,7 +25,7 @@ final class MockRequestExecutor { case unexpectedByteBuffer } - enum RequestParts: Equatable { + enum RequestParts: Equatable, Sendable { case body(IOData) case endOfStream @@ -47,7 +48,7 @@ final class MockRequestExecutor { } var requestBodyPartsCount: Int { - return self.blockingQueue.count + self.blockingQueue.count } let eventLoop: EventLoop @@ -57,10 +58,15 @@ final class MockRequestExecutor { private let responseBodyDemandLock = ConditionLock(value: false) private let cancellationLock = ConditionLock(value: false) - private var request: HTTPExecutableRequest? - private var _signaledDemandForRequestBody: Bool = false + private struct State: Sendable { + var request: HTTPExecutableRequest? + var _signaledDemandForRequestBody: Bool = false + } + + private let state: NIOLockedValueBox init(pauseRequestBodyPartStreamAfterASingleWrite: Bool = false, eventLoop: EventLoop) { + self.state = NIOLockedValueBox(State()) self.pauseRequestBodyPartStreamAfterASingleWrite = pauseRequestBodyPartStreamAfterASingleWrite self.eventLoop = eventLoop } @@ -76,13 +82,16 @@ final class MockRequestExecutor { } private func runRequest0(_ request: HTTPExecutableRequest) { - precondition(self.request == nil) - self.request = request + self.state.withLockedValue { + precondition($0.request == nil) + $0.request = request + } request.willExecuteRequest(self) request.requestHeadSent() } - func receiveRequestBody(deadline: NIODeadline = .now() + .seconds(5), _ verify: (ByteBuffer) throws -> Void) throws { + func receiveRequestBody(deadline: NIODeadline = .now() + .seconds(5), _ verify: (ByteBuffer) throws -> Void) throws + { enum ReceiveAction { case value(RequestParts) case future(EventLoopFuture) @@ -125,10 +134,16 @@ final class MockRequestExecutor { } private func pauseRequestBodyStream0() { - if self._signaledDemandForRequestBody == true { - self._signaledDemandForRequestBody = false - self.request!.pauseRequestBodyStream() + let request = self.state.withLockedValue { + if $0._signaledDemandForRequestBody == true { + $0._signaledDemandForRequestBody = false + return $0.request + } else { + return nil + } } + + request?.pauseRequestBodyStream() } func resumeRequestBodyStream() { @@ -142,10 +157,16 @@ final class MockRequestExecutor { } private func resumeRequestBodyStream0() { - if self._signaledDemandForRequestBody == false { - self._signaledDemandForRequestBody = true - self.request!.resumeRequestBodyStream() + let request = self.state.withLockedValue { + if $0._signaledDemandForRequestBody == false { + $0._signaledDemandForRequestBody = true + return $0.request + } else { + return nil + } } + + request?.resumeRequestBodyStream() } func resetResponseStreamDemandSignal() { @@ -155,10 +176,11 @@ final class MockRequestExecutor { func receiveResponseDemand(deadline: NIODeadline = .now() + .seconds(5)) throws { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.responseBodyDemandLock.lock( - whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) - ) + guard + self.responseBodyDemandLock.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) else { throw TimeoutError() } @@ -168,10 +190,11 @@ final class MockRequestExecutor { func receiveCancellation(deadline: NIODeadline = .now() + .seconds(5)) throws { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.cancellationLock.lock( - whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) - ) + guard + self.cancellationLock.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) else { throw TimeoutError() } @@ -200,11 +223,13 @@ extension MockRequestExecutor: HTTPRequestExecutor { case none } - let stateChange = { () -> WriteAction in + let stateChange = { @Sendable () -> WriteAction in var pause = false if self.blockingQueue.isEmpty && self.pauseRequestBodyPartStreamAfterASingleWrite && part.isBody { pause = true - self._signaledDemandForRequestBody = false + self.state.withLockedValue { + $0._signaledDemandForRequestBody = false + } } self.blockingQueue.append(.success(part)) @@ -265,8 +290,12 @@ extension MockRequestExecutor { internal func popFirst(deadline: NIODeadline) throws -> Element { let secondsUntilDeath = deadline - NIODeadline.now() - guard self.condition.lock(whenValue: true, - timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000)) else { + guard + self.condition.lock( + whenValue: true, + timeoutSeconds: .init(secondsUntilDeath.nanoseconds / 1_000_000_000) + ) + else { throw TimeoutError() } let first = self.buffer.removeFirst() @@ -275,3 +304,5 @@ extension MockRequestExecutor { } } } + +extension MockRequestExecutor.BlockingQueue: @unchecked Sendable where Element: Sendable {} diff --git a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift index 520b51875..44e820444 100644 --- a/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift +++ b/Tests/AsyncHTTPClientTests/Mocks/MockRequestQueuer.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOCore import NIOHTTP1 +@testable import AsyncHTTPClient + /// A mock request queue (not creating any timers) that is used to validate /// request actions returned by the `HTTPConnectionPool.StateMachine`. struct MockRequestQueuer { diff --git a/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift index ff9e7f45d..63eaf649d 100644 --- a/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/NWWaitingHandlerTests.swift @@ -16,6 +16,7 @@ @testable import AsyncHTTPClient import Network import NIOCore +import NIOConcurrencyHelpers import NIOEmbedded import NIOSSL import NIOTransportServices @@ -23,21 +24,41 @@ import XCTest @available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 5.0, *) class NWWaitingHandlerTests: XCTestCase { - class MockRequester: HTTPConnectionRequester { - var waitingForConnectivityCalled = false - var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? - var transientError: NWError? + final class MockRequester: HTTPConnectionRequester { + private struct State: Sendable { + var waitingForConnectivityCalled = false + var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? + var transientError: NWError? + } + + private let state = NIOLockedValueBox(State()) + + var waitingForConnectivityCalled: Bool { + self.state.withLockedValue { $0.waitingForConnectivityCalled } + } + + var connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID? { + self.state.withLockedValue { $0.connectionID } + } + + var transientError: NWError? { + self.state.withLockedValue { + $0.transientError + } + } - func http1ConnectionCreated(_: AsyncHTTPClient.HTTP1Connection) {} + func http1ConnectionCreated(_: AsyncHTTPClient.HTTP1Connection.SendableView) {} - func http2ConnectionCreated(_: AsyncHTTPClient.HTTP2Connection, maximumStreams: Int) {} + func http2ConnectionCreated(_: AsyncHTTPClient.HTTP2Connection.SendableView, maximumStreams: Int) {} func failedToCreateHTTPConnection(_: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) {} func waitingForConnectivity(_ connectionID: AsyncHTTPClient.HTTPConnectionPool.Connection.ID, error: Error) { - self.waitingForConnectivityCalled = true - self.connectionID = connectionID - self.transientError = error as? NWError + self.state.withLockedValue { + $0.waitingForConnectivityCalled = true + $0.connectionID = connectionID + $0.transientError = error as? NWError + } } } @@ -47,9 +68,14 @@ class NWWaitingHandlerTests: XCTestCase { let waitingEventHandler = NWWaitingHandler(requester: requester, connectionID: connectionID) let embedded = EmbeddedChannel(handlers: [waitingEventHandler]) - embedded.pipeline.fireUserInboundEventTriggered(NIOTSNetworkEvents.WaitingForConnectivity(transientError: .dns(1))) + embedded.pipeline.fireUserInboundEventTriggered( + NIOTSNetworkEvents.WaitingForConnectivity(transientError: .dns(1)) + ) - XCTAssertTrue(requester.waitingForConnectivityCalled, "Expected the handler to invoke .waitingForConnectivity on the requester") + XCTAssertTrue( + requester.waitingForConnectivityCalled, + "Expected the handler to invoke .waitingForConnectivity on the requester" + ) XCTAssertEqual(requester.connectionID, connectionID, "Expected the handler to pass connectionID to requester") XCTAssertEqual(requester.transientError, NWError.dns(1)) } @@ -60,7 +86,10 @@ class NWWaitingHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [waitingEventHandler]) embedded.pipeline.fireUserInboundEventTriggered(NIOTSNetworkEvents.BetterPathAvailable()) - XCTAssertFalse(requester.waitingForConnectivityCalled, "Should not call .waitingForConnectivity on unrelated events") + XCTAssertFalse( + requester.waitingForConnectivityCalled, + "Should not call .waitingForConnectivity on unrelated events" + ) } func testWaitingHandlerPassesTheEventDownTheContext() { diff --git a/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift b/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift index 41285d5c5..026a45d4c 100644 --- a/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift +++ b/Tests/AsyncHTTPClientTests/NoBytesSentOverBodyLimitTests.swift @@ -14,9 +14,6 @@ import AsyncHTTPClient import Atomics -#if canImport(Network) -import Network -#endif import Logging import NIOConcurrencyHelpers import NIOCore @@ -29,6 +26,10 @@ import NIOTestUtils import NIOTransportServices import XCTest +#if canImport(Network) +import Network +#endif + final class NoBytesSentOverBodyLimitTests: XCTestCaseHTTPClientTestsBaseClass { func testNoBytesSentOverBodyLimit() throws { let server = NIOHTTP1TestServer(group: self.serverGroup) @@ -40,7 +41,7 @@ final class NoBytesSentOverBodyLimitTests: XCTestCaseHTTPClientTestsBaseClass { let request = try Request( url: "http://localhost:\(server.serverPort)", - body: .stream(length: 1) { streamWriter in + body: .stream(contentLength: 1) { streamWriter in streamWriter.write(.byteBuffer(ByteBuffer(string: tooLong))) } ) diff --git a/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift b/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift index fd8e45273..35a09c421 100644 --- a/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift +++ b/Tests/AsyncHTTPClientTests/RacePoolIdleConnectionsAndGetTests.swift @@ -14,9 +14,6 @@ import AsyncHTTPClient import Atomics -#if canImport(Network) -import Network -#endif import Logging import NIOConcurrencyHelpers import NIOCore @@ -29,10 +26,16 @@ import NIOTestUtils import NIOTransportServices import XCTest +#if canImport(Network) +import Network +#endif + final class RacePoolIdleConnectionsAndGetTests: XCTestCaseHTTPClientTestsBaseClass { func testRacePoolIdleConnectionsAndGet() { - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(10)))) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: .init(connectionPool: .init(idleTimeout: .milliseconds(10))) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) } diff --git a/Tests/AsyncHTTPClientTests/RequestBagTests.swift b/Tests/AsyncHTTPClientTests/RequestBagTests.swift index 610e429f5..2b0c2f6e4 100644 --- a/Tests/AsyncHTTPClientTests/RequestBagTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestBagTests.swift @@ -12,38 +12,54 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient +import Atomics import Logging +import NIOConcurrencyHelpers import NIOCore import NIOEmbedded import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + final class RequestBagTests: XCTestCase { func testWriteBackpressureWorks() { let embeddedEventLoop = EmbeddedEventLoop() defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } let logger = Logger(label: "test") - var writtenBytes = 0 - var writes = 0 + struct TestState { + var writtenBytes: Int = 0 + var writes: Int = 0 + var streamIsAllowedToWrite: Bool = false + } + + let testState = NIOLockedValueBox(TestState()) + let bytesToSent = (3000...10000).randomElement()! let expectedWrites = bytesToSent / 100 + ((bytesToSent % 100 > 0) ? 1 : 0) - var streamIsAllowedToWrite = false let writeDonePromise = embeddedEventLoop.makePromise(of: Void.self) - let requestBody: HTTPClient.Body = .stream(length: bytesToSent) { writer -> EventLoopFuture in - func write(donePromise: EventLoopPromise) { - XCTAssertTrue(streamIsAllowedToWrite) - guard writtenBytes < bytesToSent else { - return donePromise.succeed(()) + let requestBody: HTTPClient.Body = .stream(contentLength: Int64(bytesToSent)) { + writer -> EventLoopFuture in + @Sendable func write(donePromise: EventLoopPromise) { + let futureWrite: EventLoopFuture? = testState.withLockedValue { state in + XCTAssertTrue(state.streamIsAllowedToWrite) + guard state.writtenBytes < bytesToSent else { + donePromise.succeed(()) + return nil + } + let byteCount = min(bytesToSent - state.writtenBytes, 100) + let buffer = ByteBuffer(bytes: [UInt8](repeating: 1, count: byteCount)) + state.writes += 1 + return writer.write(.byteBuffer(buffer)) } - let byteCount = min(bytesToSent - writtenBytes, 100) - let buffer = ByteBuffer(bytes: [UInt8](repeating: 1, count: byteCount)) - writes += 1 - writer.write(.byteBuffer(buffer)).whenSuccess { _ in - writtenBytes += 100 + + futureWrite?.whenSuccess { _ in + testState.withLockedValue { state in + state.writtenBytes += 100 + } write(donePromise: donePromise) } } @@ -54,20 +70,24 @@ final class RequestBagTests: XCTestCase { } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.task.eventLoop === embeddedEventLoop) @@ -81,46 +101,57 @@ final class RequestBagTests: XCTestCase { executor.runRequest(bag) XCTAssertEqual(delegate.hitDidSendRequestHead, 1) - streamIsAllowedToWrite = true + testState.withLockedValue { $0.streamIsAllowedToWrite = true } bag.resumeRequestBodyStream() - streamIsAllowedToWrite = false + testState.withLockedValue { $0.streamIsAllowedToWrite = false } // after starting the body stream we should have received two writes var receivedBytes = 0 for i in 0.. EventLoopFuture in + let requestBody: HTTPClient.Body = .stream(contentLength: 12) { writer -> EventLoopFuture in writer.write(.byteBuffer(ByteBuffer(bytes: 0...3))).flatMap { _ -> EventLoopFuture in embeddedEventLoop.makeFailedFuture(TestError()) @@ -161,20 +192,24 @@ final class RequestBagTests: XCTestCase { } var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "https://swift.org", method: .POST, body: requestBody) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.task.eventLoop === embeddedEventLoop) @@ -207,15 +242,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) @@ -240,15 +277,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) @@ -279,15 +318,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) @@ -320,15 +361,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } XCTAssert(bag.eventLoop === embeddedEventLoop) @@ -361,15 +404,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let queuer = MockTaskQueuer() @@ -395,15 +440,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -423,23 +470,27 @@ final class RequestBagTests: XCTestCase { let logger = Logger(label: "test") var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://swift.org", - body: .bytes([1, 2, 3, 4, 5]) - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + body: .bytes([1, 2, 3, 4, 5]) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -463,11 +514,13 @@ final class RequestBagTests: XCTestCase { var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://swift.org", - method: .POST, - body: .byteBuffer(.init(bytes: [1])) - )) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + method: .POST, + body: .byteBuffer(.init(bytes: [1])) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } struct MyError: Error, Equatable {} @@ -493,15 +546,17 @@ final class RequestBagTests: XCTestCase { } let delegate = Delegate(didFinishPromise: embeddedEventLoop.makePromise()) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -530,42 +585,46 @@ final class RequestBagTests: XCTestCase { var maybeRequest: HTTPClient.Request? let writeSecondPartPromise = embeddedEventLoop.makePromise(of: Void.self) - - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request( - url: "https://swift.org", - method: .POST, - headers: ["content-length": "12"], - body: .stream(length: 12) { writer -> EventLoopFuture in - var firstWriteSuccess = false - return writer.write(.byteBuffer(.init(bytes: 0...3))).flatMap { _ in - firstWriteSuccess = true - - return writeSecondPartPromise.futureResult - }.flatMap { - return writer.write(.byteBuffer(.init(bytes: 4...7))) - }.always { result in - XCTAssertTrue(firstWriteSuccess) - - guard case .failure(let error) = result else { - return XCTFail("Expected the second write to fail") + let firstWriteSuccess: NIOLockedValueBox = .init(false) + + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request( + url: "https://swift.org", + method: .POST, + headers: ["content-length": "12"], + body: .stream(contentLength: 12) { writer -> EventLoopFuture in + writer.write(.byteBuffer(.init(bytes: 0...3))).flatMap { _ in + firstWriteSuccess.withLockedValue { $0 = true } + + return writeSecondPartPromise.futureResult + }.flatMap { + writer.write(.byteBuffer(.init(bytes: 4...7))) + }.always { result in + XCTAssertTrue(firstWriteSuccess.withLockedValue { $0 }) + + guard case .failure(let error) = result else { + return XCTFail("Expected the second write to fail") + } + XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) } - XCTAssertEqual(error as? HTTPClientError, .requestStreamCancelled) } - } - )) + ) + ) guard let request = maybeRequest else { return XCTFail("Expected to have a request") } let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -601,15 +660,17 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: nil, - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( + request: request, + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: nil, + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) @@ -657,36 +718,49 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? var redirectTriggered = false - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: .init( + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( request: request, - redirectState: RedirectState( - .follow(max: 5, allowCycles: false), - initialURL: request.url.absoluteString - )!, - execute: { request, _ in - XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") - XCTAssertFalse(redirectTriggered) - - let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) - task.promise.fail(HTTPClientError.cancelled) - redirectTriggered = true - return task - } - ), - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] + ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertTrue(executor.signalledDemandForResponseBody) executor.resetResponseStreamDemandSignal() @@ -732,36 +806,49 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? var redirectTriggered = false - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: .init( + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( request: request, - redirectState: RedirectState( - .follow(max: 5, allowCycles: false), - initialURL: request.url.absoluteString - )!, - execute: { request, _ in - XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") - XCTAssertFalse(redirectTriggered) - - let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) - task.promise.fail(HTTPClientError.cancelled) - redirectTriggered = true - return task - } - ), - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(4 * 1024)", "location": "https://swift.org/sswg"] + ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertFalse(executor.signalledDemandForResponseBody) XCTAssertTrue(executor.isCancelled) @@ -781,36 +868,49 @@ final class RequestBagTests: XCTestCase { let delegate = UploadCountingDelegate(eventLoop: embeddedEventLoop) var maybeRequestBag: RequestBag? var redirectTriggered = false - XCTAssertNoThrow(maybeRequestBag = try RequestBag( - request: request, - eventLoopPreference: .delegate(on: embeddedEventLoop), - task: .init(eventLoop: embeddedEventLoop, logger: logger), - redirectHandler: .init( + XCTAssertNoThrow( + maybeRequestBag = try RequestBag( request: request, - redirectState: RedirectState( - .follow(max: 5, allowCycles: false), - initialURL: request.url.absoluteString - )!, - execute: { request, _ in - XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") - XCTAssertFalse(redirectTriggered) - - let task = HTTPClient.Task(eventLoop: embeddedEventLoop, logger: logger) - task.promise.fail(HTTPClientError.cancelled) - redirectTriggered = true - return task - } - ), - connectionDeadline: .now() + .seconds(30), - requestOptions: .forTests(), - delegate: delegate - )) + eventLoopPreference: .delegate(on: embeddedEventLoop), + task: .init(eventLoop: embeddedEventLoop, logger: logger), + redirectHandler: .init( + request: request, + redirectState: RedirectState( + .follow(max: 5, allowCycles: false), + initialURL: request.url.absoluteString + )!, + execute: { request, _ in + XCTAssertEqual(request.url.absoluteString, "https://swift.org/sswg") + XCTAssertFalse(redirectTriggered) + + let task = HTTPClient.Task( + eventLoop: embeddedEventLoop, + logger: logger + ) + task.promise.fail(HTTPClientError.cancelled) + redirectTriggered = true + return task + } + ), + connectionDeadline: .now() + .seconds(30), + requestOptions: .forTests(), + delegate: delegate + ) + ) guard let bag = maybeRequestBag else { return XCTFail("Expected to be able to create a request bag.") } let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) executor.runRequest(bag) XCTAssertFalse(executor.signalledDemandForResponseBody) - bag.receiveResponseHead(.init(version: .http1_1, status: .permanentRedirect, headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"])) + XCTAssertTrue(delegate.history.isEmpty) + let responseHead = HTTPResponseHead( + version: .http1_1, + status: .permanentRedirect, + headers: ["content-length": "\(3 * 1024)", "location": "https://swift.org/sswg"] + ) + bag.receiveResponseHead(responseHead) + XCTAssertEqual(delegate.history.map(\.request.url), [request.url]) + XCTAssertEqual(delegate.history.map(\.response), [responseHead]) XCTAssertNil(delegate.backpressurePromise) XCTAssertTrue(executor.signalledDemandForResponseBody) executor.resetResponseStreamDemandSignal() @@ -839,7 +939,7 @@ final class RequestBagTests: XCTestCase { } func testWeDontLeakTheRequestIfTheRequestWriterWasCapturedByAPromise() { - final class LeakDetector {} + final class LeakDetector: Sendable {} let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) defer { XCTAssertNoThrow(try group.syncShutdownGracefully()) } @@ -854,7 +954,9 @@ final class RequestBagTests: XCTestCase { do { var maybeRequest: HTTPClient.Request? - XCTAssertNoThrow(maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST)) + XCTAssertNoThrow( + maybeRequest = try HTTPClient.Request(url: "http://localhost:\(httpBin.port)/", method: .POST) + ) guard var request = maybeRequest else { return XCTFail("Expected to have a request here") } let writerPromise = group.any().makePromise(of: HTTPClient.Body.StreamWriter.self) @@ -898,71 +1000,100 @@ extension HTTPClient.Task { } } -class UploadCountingDelegate: HTTPClientResponseDelegate { +final class UploadCountingDelegate: HTTPClientResponseDelegate { typealias Response = Void let eventLoop: EventLoop - private(set) var hitDidSendRequestHead = 0 - private(set) var hitDidSendRequestPart = 0 - private(set) var hitDidSendRequest = 0 - private(set) var hitDidReceiveResponse = 0 - private(set) var hitDidReceiveBodyPart = 0 - private(set) var hitDidReceiveError = 0 + struct State: Sendable { + var hitDidSendRequestHead = 0 + var hitDidSendRequestPart = 0 + var hitDidSendRequest = 0 + var hitDidReceiveResponse = 0 + var hitDidReceiveBodyPart = 0 + var hitDidReceiveError = 0 + + var history: [(request: HTTPClient.Request, response: HTTPResponseHead)] = [] + var receivedHead: HTTPResponseHead? + var lastBodyPart: ByteBuffer? + var backpressurePromise: EventLoopPromise? + var lastError: Error? + } + + private let state: NIOLoopBoundBox - private(set) var receivedHead: HTTPResponseHead? - private(set) var lastBodyPart: ByteBuffer? - private(set) var backpressurePromise: EventLoopPromise? - private(set) var lastError: Error? + var hitDidSendRequestHead: Int { self.state.value.hitDidSendRequestHead } + var hitDidSendRequestPart: Int { self.state.value.hitDidSendRequestPart } + var hitDidSendRequest: Int { self.state.value.hitDidSendRequest } + var hitDidReceiveResponse: Int { self.state.value.hitDidReceiveResponse } + var hitDidReceiveBodyPart: Int { self.state.value.hitDidReceiveBodyPart } + var hitDidReceiveError: Int { self.state.value.hitDidReceiveError } + + var history: [(request: HTTPClient.Request, response: HTTPResponseHead)] { + self.state.value.history + } + var receivedHead: HTTPResponseHead? { self.state.value.receivedHead } + var lastBodyPart: ByteBuffer? { self.state.value.lastBodyPart } + var backpressurePromise: EventLoopPromise? { self.state.value.backpressurePromise } + var lastError: Error? { self.state.value.lastError } init(eventLoop: EventLoop) { self.eventLoop = eventLoop + self.state = .makeBoxSendingValue(State(), eventLoop: eventLoop) } func didSendRequestHead(task: HTTPClient.Task, _ head: HTTPRequestHead) { - self.hitDidSendRequestHead += 1 + self.state.value.hitDidSendRequestHead += 1 } func didSendRequestPart(task: HTTPClient.Task, _ part: IOData) { - self.hitDidSendRequestPart += 1 + self.state.value.hitDidSendRequestPart += 1 } func didSendRequest(task: HTTPClient.Task) { - self.hitDidSendRequest += 1 + self.state.value.hitDidSendRequest += 1 + } + + func didVisitURL(task: HTTPClient.Task, _ request: HTTPClient.Request, _ head: HTTPResponseHead) { + self.state.value.history.append((request, head)) } func didReceiveHead(task: HTTPClient.Task, _ head: HTTPResponseHead) -> EventLoopFuture { - self.receivedHead = head + self.state.value.receivedHead = head return self.createBackpressurePromise() } func didReceiveBodyPart(task: HTTPClient.Task, _ buffer: ByteBuffer) -> EventLoopFuture { - assert(self.backpressurePromise == nil) - self.hitDidReceiveBodyPart += 1 - self.lastBodyPart = buffer + assert(self.state.value.backpressurePromise == nil) + self.state.value.hitDidReceiveBodyPart += 1 + self.state.value.lastBodyPart = buffer return self.createBackpressurePromise() } func didFinishRequest(task: HTTPClient.Task) throws { - self.hitDidReceiveResponse += 1 + self.state.value.hitDidReceiveResponse += 1 } func didReceiveError(task: HTTPClient.Task, _ error: Error) { - self.hitDidReceiveError += 1 - self.lastError = error + self.state.value.hitDidReceiveError += 1 + self.state.value.lastError = error } private func createBackpressurePromise() -> EventLoopFuture { - assert(self.backpressurePromise == nil) - self.backpressurePromise = self.eventLoop.makePromise(of: Void.self) - return self.backpressurePromise!.futureResult.always { _ in - self.backpressurePromise = nil + assert(self.state.value.backpressurePromise == nil) + self.state.value.backpressurePromise = self.eventLoop.makePromise(of: Void.self) + return self.state.value.backpressurePromise!.futureResult.always { _ in + self.state.value.backpressurePromise = nil } } } final class MockTaskQueuer: HTTPRequestScheduler { - private(set) var hitCancelCount = 0 + private let _hitCancelCount = ManagedAtomic(0) + + var hitCancelCount: Int { + self._hitCancelCount.load(ordering: .sequentiallyConsistent) + } let onCancelRequest: (@Sendable (HTTPSchedulableRequest) -> Void)? @@ -971,7 +1102,7 @@ final class MockTaskQueuer: HTTPRequestScheduler { } func cancelRequest(_ request: HTTPSchedulableRequest) { - self.hitCancelCount += 1 + self._hitCancelCount.wrappingIncrement(ordering: .sequentiallyConsistent) self.onCancelRequest?(request) } } diff --git a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift index c50d3afd1..ea5a6bd66 100644 --- a/Tests/AsyncHTTPClientTests/RequestValidationTests.swift +++ b/Tests/AsyncHTTPClientTests/RequestValidationTests.swift @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + class RequestValidationTests: XCTestCase { func testContentLengthHeaderIsRemovedFromGETIfNoBody() { var headers = HTTPHeaders([("Content-Length", "0")]) @@ -29,13 +30,17 @@ class RequestValidationTests: XCTestCase { func testContentLengthHeaderIsAddedToPOSTAndPUTWithNoBody() { var putHeaders = HTTPHeaders() var putMetadata: RequestFramingMetadata? - XCTAssertNoThrow(putMetadata = try putHeaders.validateAndSetTransportFraming(method: .PUT, bodyLength: .known(0))) + XCTAssertNoThrow( + putMetadata = try putHeaders.validateAndSetTransportFraming(method: .PUT, bodyLength: .known(0)) + ) XCTAssertEqual(putHeaders.first(name: "Content-Length"), "0") XCTAssertEqual(putMetadata?.body, .fixedSize(0)) var postHeaders = HTTPHeaders() var postMetadata: RequestFramingMetadata? - XCTAssertNoThrow(postMetadata = try postHeaders.validateAndSetTransportFraming(method: .POST, bodyLength: .known(0))) + XCTAssertNoThrow( + postMetadata = try postHeaders.validateAndSetTransportFraming(method: .POST, bodyLength: .known(0)) + ) XCTAssertEqual(postHeaders.first(name: "Content-Length"), "0") XCTAssertEqual(postMetadata?.body, .fixedSize(0)) } @@ -90,7 +95,7 @@ class RequestValidationTests: XCTestCase { func testMetadataDetectConnectionClose() { var headers = HTTPHeaders([ - ("Connection", "close"), + ("Connection", "close") ]) var metadata: RequestFramingMetadata? XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: .GET, bodyLength: .known(0))) @@ -114,7 +119,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -123,7 +130,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -139,7 +148,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -149,7 +160,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown)) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .stream) @@ -159,7 +172,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -169,7 +184,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init() var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown)) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .unknown) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .stream) @@ -184,7 +201,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -193,7 +212,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -208,7 +229,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -217,7 +240,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Content-Length", "1")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(1)) + ) XCTAssertEqual(headers["content-length"].first, "1") XCTAssertTrue(headers["transfer-encoding"].isEmpty) XCTAssertEqual(metadata?.body, .fixedSize(1)) @@ -232,7 +257,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.GET, .HEAD, .DELETE, .CONNECT, .TRACE] { var headers: HTTPHeaders = .init([("Transfer-Encoding", "chunked")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertTrue(headers["content-length"].isEmpty) XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -241,7 +268,9 @@ class RequestValidationTests: XCTestCase { for method: HTTPMethod in [.POST, .PUT] { var headers: HTTPHeaders = .init([("Transfer-Encoding", "chunked")]) var metadata: RequestFramingMetadata? - XCTAssertNoThrow(metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0))) + XCTAssertNoThrow( + metadata = try headers.validateAndSetTransportFraming(method: method, bodyLength: .known(0)) + ) XCTAssertEqual(headers["content-length"].first, "0") XCTAssertFalse(headers["transfer-encoding"].contains("chunked")) XCTAssertEqual(metadata?.body, .fixedSize(0)) @@ -337,21 +366,27 @@ class RequestValidationTests: XCTestCase { func testTransferEncodingsAreOverwrittenIfBodyLengthIsFixed() { var headers: HTTPHeaders = [ - "Transfer-Encoding": "gzip, chunked", + "Transfer-Encoding": "gzip, chunked" ] XCTAssertNoThrow(try headers.validateAndSetTransportFraming(method: .POST, bodyLength: .known(1))) - XCTAssertEqual(headers, [ - "Content-Length": "1", - ]) + XCTAssertEqual( + headers, + [ + "Content-Length": "1" + ] + ) } func testTransferEncodingsAreOverwrittenIfBodyLengthIsDynamic() { var headers: HTTPHeaders = [ - "Transfer-Encoding": "gzip, chunked", + "Transfer-Encoding": "gzip, chunked" ] XCTAssertNoThrow(try headers.validateAndSetTransportFraming(method: .POST, bodyLength: .unknown)) - XCTAssertEqual(headers, [ - "Transfer-Encoding": "chunked", - ]) + XCTAssertEqual( + headers, + [ + "Transfer-Encoding": "chunked" + ] + ) } } diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem index f6314d47a..f16590cde 100644 --- a/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.cert.pem @@ -1,12 +1,12 @@ -----BEGIN CERTIFICATE----- -MIIBxDCCAUmgAwIBAgIVAPY31L1kyEnjO1E4inpE7+SYRO9mMAoGCCqGSM49BAMD -MCoxFDASBgNVBAoMC1NlbGYgU2lnbmVkMRIwEAYDVQQDDAlsb2NhbGhvc3QwHhcN -MjQwMzI4MjI0MDUyWhcNMjUwMzI4MjI0MDUyWjAqMRQwEgYDVQQKDAtTZWxmIFNp -Z25lZDESMBAGA1UEAwwJbG9jYWxob3N0MHYwEAYHKoZIzj0CAQYFK4EEACIDYgAE -o2i+uiLtMu0Jzsk3oEUnfoM9n44/aV9UeOXxyDs57i2E13HrJeWIXACetybkB+Q8 -Poab6ohbskTwrS7WN3tFgoGdRBCKQow/rTECdezR/fdz2cGADaBN+CNMuFSnFSr5 -oy8wLTAWBgNVHREEDzANggtleGFtcGxlLmNvbTATBgNVHSUEDDAKBggrBgEFBQcD -ATAKBggqhkjOPQQDAwNpADBmAjEAwF5OlUBOloDTIAxgaSSvHBMSVOE1rY5hUlkT -kQ+dQFeUe3Fn+Er5ohvkt+qVOQ5yAjEAt9s5b/Iz+JmWxKKUyExHob6QHEuuHmJy -AKdrn20Ply60bb8qxGYHhwhoyV2MZYVV +MIIBwTCCAUigAwIBAgIUX7f9BABxGdAqG5EvLpQScFt9lOkwCgYIKoZIzj0EAwMw +KjEUMBIGA1UECgwLU2VsZiBTaWduZWQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0y +NTA0MDExNDMwMTFaFw0yNjA0MDExNDMwMTFaMCoxFDASBgNVBAoMC1NlbGYgU2ln +bmVkMRIwEAYDVQQDDAlsb2NhbGhvc3QwdjAQBgcqhkjOPQIBBgUrgQQAIgNiAAQW +szfO5HCWIWgKUqyXUU0pFpYgaq01RRL69XZz1CkV6XTrxMfIvvwez2886EQDL8QX +i5NpKg3qvPgWuDjVHaj4WEJe5XMNqcujxcTufBlmaQ6o4vtoK7CIHDIDldF/HRij +LzAtMBYGA1UdEQQPMA2CC2V4YW1wbGUuY29tMBMGA1UdJQQMMAoGCCsGAQUFBwMB +MAoGCCqGSM49BAMDA2cAMGQCMBJ8Dxg0qX2bEZ3r6dI3UCGAUYxJDVk+XhiIY1Fm +5FJeQqhaVayCRPrPXXGZUJGY/wIwXej70FwkxHKLq+XxfHTC5CzmoOK469C9Rk9Y +ucddXM83ebFxVNgRCWetH9tDdXJ9 -----END CERTIFICATE----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem index 7cf27cc35..3ad9ce79e 100644 --- a/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem +++ b/Tests/AsyncHTTPClientTests/Resources/example.com.private-key.pem @@ -1,6 +1,6 @@ -----BEGIN PRIVATE KEY----- -MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDhC5OSjPQeYRm4irIH -z4EyM/NbJsX39SlI6J4/q0Syt0BwojgJKhCWfeveanbIjbWhZANiAASjaL66Iu0y -7QnOyTegRSd+gz2fjj9pX1R45fHIOznuLYTXcesl5YhcAJ63JuQH5Dw+hpvqiFuy -RPCtLtY3e0WCgZ1EEIpCjD+tMQJ17NH993PZwYANoE34I0y4VKcVKvk= +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDD9v51MTOcgFIbiHbok +U+QOubosGF1u1q+D3fEUb1U2cgjCofKmPHekXTz0xu9MJi2hZANiAAQWszfO5HCW +IWgKUqyXUU0pFpYgaq01RRL69XZz1CkV6XTrxMfIvvwez2886EQDL8QXi5NpKg3q +vPgWuDjVHaj4WEJe5XMNqcujxcTufBlmaQ6o4vtoK7CIHDIDldF/HRg= -----END PRIVATE KEY----- \ No newline at end of file diff --git a/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift b/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift index 0af5c7243..5fd1d6720 100644 --- a/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift +++ b/Tests/AsyncHTTPClientTests/ResponseDelayGetTests.swift @@ -14,9 +14,6 @@ import AsyncHTTPClient import Atomics -#if canImport(Network) -import Network -#endif import Logging import NIOConcurrencyHelpers import NIOCore @@ -29,15 +26,21 @@ import NIOTestUtils import NIOTransportServices import XCTest +#if canImport(Network) +import Network +#endif + final class ResponseDelayGetTests: XCTestCaseHTTPClientTestsBaseClass { func testResponseDelayGet() throws { - let req = try HTTPClient.Request(url: self.defaultHTTPBinURLPrefix + "get", - method: .GET, - headers: ["X-internal-delay": "2000"], - body: nil) + let req = try HTTPClient.Request( + url: self.defaultHTTPBinURLPrefix + "get", + method: .GET, + headers: ["X-internal-delay": "2000"], + body: nil + ) let start = NIODeadline.now() let response = try self.defaultClient.execute(request: req).wait() - XCTAssertGreaterThanOrEqual(.now() - start, .milliseconds(1_900 /* 1.9 seconds */ )) + XCTAssertGreaterThanOrEqual(.now() - start, .milliseconds(1_900)) XCTAssertEqual(response.status, .ok) } } diff --git a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift index 066a631a5..2352c6c1c 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSEventsHandlerTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOSOCKS import XCTest +@testable import AsyncHTTPClient + class SOCKSEventsHandlerTests: XCTestCase { func testHandlerHappyPath() { let socksEventsHandler = SOCKSEventsHandler(deadline: .now() + .seconds(10)) @@ -37,7 +38,7 @@ class SOCKSEventsHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [socksEventsHandler]) XCTAssertNotNil(socksEventsHandler.socksEstablishedFuture) - XCTAssertNoThrow(try embedded.pipeline.removeHandler(socksEventsHandler).wait()) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.removeHandler(socksEventsHandler).wait()) XCTAssertThrowsError(try XCTUnwrap(socksEventsHandler.socksEstablishedFuture).wait()) } diff --git a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift index d888769b4..50d26b278 100644 --- a/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift +++ b/Tests/AsyncHTTPClientTests/SOCKSTestUtils.swift @@ -40,25 +40,40 @@ class MockSOCKSServer { self.channel.localAddress!.port! } - init(expectedURL: String, expectedResponse: String, misbehave: Bool = false, file: String = #filePath, line: UInt = #line) throws { + init( + expectedURL: String, + expectedResponse: String, + misbehave: Bool = false, + file: String = #filePath, + line: UInt = #line + ) throws { let elg = MultiThreadedEventLoopGroup(numberOfThreads: 1) let bootstrap: ServerBootstrap if misbehave { bootstrap = ServerBootstrap(group: elg) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in - channel.pipeline.addHandler(TestSOCKSBadServerHandler()) + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(TestSOCKSBadServerHandler()) + } } } else { bootstrap = ServerBootstrap(group: elg) .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) .childChannelInitializer { channel in - let handshakeHandler = SOCKSServerHandshakeHandler() - return channel.pipeline.addHandlers([ - handshakeHandler, - SOCKSTestHandler(handshakeHandler: handshakeHandler), - TestHTTPServer(expectedURL: expectedURL, expectedResponse: expectedResponse, file: file, line: line), - ]) + channel.eventLoop.makeCompletedFuture { + let handshakeHandler = SOCKSServerHandshakeHandler() + try channel.pipeline.syncOperations.addHandlers([ + handshakeHandler, + SOCKSTestHandler(handshakeHandler: handshakeHandler), + TestHTTPServer( + expectedURL: expectedURL, + expectedResponse: expectedResponse, + file: file, + line: line + ), + ]) + } } } self.channel = try bootstrap.bind(host: "localhost", port: 0).wait() @@ -86,19 +101,34 @@ class SOCKSTestHandler: ChannelInboundHandler, RemovableChannelHandler { let message = self.unwrapInboundIn(data) switch message { case .greeting: - context.writeAndFlush(.init( - ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired))), promise: nil) + context.writeAndFlush( + .init( + ServerMessage.selectedAuthenticationMethod(.init(method: .noneRequired)) + ), + promise: nil + ) case .authenticationData: context.fireErrorCaught(MockSOCKSError(description: "Received authentication data but didn't receive any.")) case .request(let request): - context.writeAndFlush(.init( - ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType))), promise: nil) - context.channel.pipeline.addHandlers([ - ByteToMessageHandler(HTTPRequestDecoder()), - HTTPResponseEncoder(), - ], position: .after(self)).whenSuccess { - context.channel.pipeline.removeHandler(self, promise: nil) - context.channel.pipeline.removeHandler(self.handshakeHandler, promise: nil) + context.writeAndFlush( + .init( + ServerMessage.response(.init(reply: .succeeded, boundAddress: request.addressType)) + ), + promise: nil + ) + + do { + try context.channel.pipeline.syncOperations.addHandlers( + [ + ByteToMessageHandler(HTTPRequestDecoder()), + HTTPResponseEncoder(), + ], + position: .after(self) + ) + context.channel.pipeline.syncOperations.removeHandler(self, promise: nil) + context.channel.pipeline.syncOperations.removeHandler(self.handshakeHandler, promise: nil) + } catch { + context.fireErrorCaught(error) } } } @@ -134,7 +164,12 @@ class TestHTTPServer: ChannelInboundHandler { break case .end: context.write(self.wrapOutboundOut(.head(.init(version: .http1_1, status: .ok))), promise: nil) - context.write(self.wrapOutboundOut(.body(.byteBuffer(context.channel.allocator.buffer(string: self.expectedResponse)))), promise: nil) + context.write( + self.wrapOutboundOut( + .body(.byteBuffer(context.channel.allocator.buffer(string: self.expectedResponse))) + ), + promise: nil + ) context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil) } } diff --git a/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift index 438c643d7..c7588cc7d 100644 --- a/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift +++ b/Tests/AsyncHTTPClientTests/SSLContextCacheTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOPosix import NIOSSL import XCTest +@testable import AsyncHTTPClient + final class SSLContextCacheTests: XCTestCase { func testRequestingSSLContextWorks() { let group = MultiThreadedEventLoopGroup(numberOfThreads: 1) @@ -27,9 +28,13 @@ final class SSLContextCacheTests: XCTestCase { XCTAssertNoThrow(try group.syncShutdownGracefully()) } - XCTAssertNoThrow(try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) } func testCacheWorks() { @@ -43,12 +48,20 @@ final class SSLContextCacheTests: XCTestCase { var firstContext: NIOSSLContext? var secondContext: NIOSSLContext? - XCTAssertNoThrow(firstContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) - XCTAssertNoThrow(secondContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + firstContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) + XCTAssertNoThrow( + secondContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) XCTAssertNotNil(firstContext) XCTAssertNotNil(secondContext) XCTAssert(firstContext === secondContext) @@ -65,16 +78,24 @@ final class SSLContextCacheTests: XCTestCase { var firstContext: NIOSSLContext? var secondContext: NIOSSLContext? - XCTAssertNoThrow(firstContext = try cache.sslContext(tlsConfiguration: .makeClientConfiguration(), - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + firstContext = try cache.sslContext( + tlsConfiguration: .makeClientConfiguration(), + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) // Second one has a _different_ TLSConfiguration. var testTLSConfig = TLSConfiguration.makeClientConfiguration() testTLSConfig.certificateVerification = .none - XCTAssertNoThrow(secondContext = try cache.sslContext(tlsConfiguration: testTLSConfig, - eventLoop: eventLoop, - logger: HTTPClient.loggingDisabled).wait()) + XCTAssertNoThrow( + secondContext = try cache.sslContext( + tlsConfiguration: testTLSConfig, + eventLoop: eventLoop, + logger: HTTPClient.loggingDisabled + ).wait() + ) XCTAssertNotNil(firstContext) XCTAssertNotNil(secondContext) XCTAssert(firstContext !== secondContext) diff --git a/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift b/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift index 4c5cd1816..587e6c64c 100644 --- a/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift +++ b/Tests/AsyncHTTPClientTests/StressGetHttpsTests.swift @@ -14,9 +14,6 @@ import AsyncHTTPClient import Atomics -#if canImport(Network) -import Network -#endif import Logging import NIOConcurrencyHelpers import NIOCore @@ -29,11 +26,17 @@ import NIOTestUtils import NIOTransportServices import XCTest +#if canImport(Network) +import Network +#endif + final class StressGetHttpsTests: XCTestCaseHTTPClientTestsBaseClass { func testStressGetHttps() throws { let localHTTPBin = HTTPBin(.http1_1(ssl: true)) - let localClient = HTTPClient(eventLoopGroupProvider: .shared(self.clientGroup), - configuration: HTTPClient.Configuration(certificateVerification: .none)) + let localClient = HTTPClient( + eventLoopGroupProvider: .shared(self.clientGroup), + configuration: HTTPClient.Configuration(certificateVerification: .none) + ) defer { XCTAssertNoThrow(try localClient.syncShutdown()) XCTAssertNoThrow(try localHTTPBin.shutdown()) @@ -43,7 +46,11 @@ final class StressGetHttpsTests: XCTestCaseHTTPClientTestsBaseClass { let requestCount = 200 var futureResults = [EventLoopFuture]() for _ in 1...requestCount { - let req = try HTTPClient.Request(url: "https://localhost:\(localHTTPBin.port)/get", method: .GET, headers: ["X-internal-delay": "100"]) + let req = try HTTPClient.Request( + url: "https://localhost:\(localHTTPBin.port)/get", + method: .GET, + headers: ["X-internal-delay": "100"] + ) futureResults.append(localClient.execute(request: req)) } XCTAssertNoThrow(try EventLoopFuture.andAllSucceed(futureResults, on: eventLoop).wait()) diff --git a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift index c119c7e50..988ba6e3f 100644 --- a/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift +++ b/Tests/AsyncHTTPClientTests/TLSEventsHandlerTests.swift @@ -12,13 +12,14 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOSSL import NIOTLS import XCTest +@testable import AsyncHTTPClient + class TLSEventsHandlerTests: XCTestCase { func testHandlerHappyPath() { let tlsEventsHandler = TLSEventsHandler(deadline: nil) @@ -38,7 +39,7 @@ class TLSEventsHandlerTests: XCTestCase { let embedded = EmbeddedChannel(handlers: [tlsEventsHandler]) XCTAssertNotNil(tlsEventsHandler.tlsEstablishedFuture) - XCTAssertNoThrow(try embedded.pipeline.removeHandler(tlsEventsHandler).wait()) + XCTAssertNoThrow(try embedded.pipeline.syncOperations.removeHandler(tlsEventsHandler).wait()) XCTAssertThrowsError(try XCTUnwrap(tlsEventsHandler.tlsEstablishedFuture).wait()) } diff --git a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift index a8d3d5a5e..a631e9a93 100644 --- a/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift +++ b/Tests/AsyncHTTPClientTests/Transaction+StateMachineTests.swift @@ -12,12 +12,13 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import NIOCore import NIOEmbedded import NIOHTTP1 import XCTest +@testable import AsyncHTTPClient + struct NoOpAsyncSequenceProducerDelegate: NIOAsyncSequenceProducerDelegate { func produceMore() {} func didTerminate() {} @@ -37,7 +38,10 @@ final class Transaction_StateMachineTests: XCTestCase { state.requestWasQueued(queuer) let failAction = state.fail(HTTPClientError.cancelled) - guard case .failResponseHead(_, let error, let scheduler, let rexecutor, let bodyStreamContinuation) = failAction else { + guard + case .failResponseHead(_, let error, let scheduler, let rexecutor, let bodyStreamContinuation) = + failAction + else { return XCTFail("Unexpected fail action: \(failAction)") } XCTAssertEqual(error as? HTTPClientError, .cancelled) @@ -88,7 +92,10 @@ final class Transaction_StateMachineTests: XCTestCase { XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) let failAction = state.fail(MyError()) - guard case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = failAction else { + guard + case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = + failAction + else { return XCTFail("Unexpected fail action: \(failAction)") } XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) @@ -118,7 +125,10 @@ final class Transaction_StateMachineTests: XCTestCase { XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) let failAction = state.fail(MyError()) - guard case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = failAction else { + guard + case .failResponseHead(let continuation, let error, nil, nil, bodyStreamContinuation: nil) = + failAction + else { return XCTFail("Unexpected fail action: \(failAction)") } XCTAssertIdentical(scheduler as? MockTaskQueuer, queuer) @@ -203,7 +213,10 @@ final class Transaction_StateMachineTests: XCTestCase { XCTAssertEqual(state.willExecuteRequest(executor), .none) state.requestWasQueued(queuer) let head = HTTPResponseHead(version: .http1_1, status: .ok) - let receiveResponseHeadAction = state.receiveResponseHead(head, delegate: NoOpAsyncSequenceProducerDelegate()) + let receiveResponseHeadAction = state.receiveResponseHead( + head, + delegate: NoOpAsyncSequenceProducerDelegate() + ) guard case .succeedResponseHead(_, let continuation) = receiveResponseHeadAction else { return XCTFail("Unexpected action: \(receiveResponseHeadAction)") } @@ -258,7 +271,7 @@ extension Transaction.StateMachine.NextWriteAction: Equatable { public static func == (lhs: Self, rhs: Self) -> Bool { switch (lhs, rhs) { case (.writeAndWait(let lhsEx), .writeAndWait(let rhsEx)), - (.writeAndContinue(let lhsEx), .writeAndContinue(let rhsEx)): + (.writeAndContinue(let lhsEx), .writeAndContinue(let rhsEx)): if let lhsMock = lhsEx as? MockRequestExecutor, let rhsMock = rhsEx as? MockRequestExecutor { return lhsMock === rhsMock } diff --git a/Tests/AsyncHTTPClientTests/TransactionTests.swift b/Tests/AsyncHTTPClientTests/TransactionTests.swift index a8a2bb30e..3316de370 100644 --- a/Tests/AsyncHTTPClientTests/TransactionTests.swift +++ b/Tests/AsyncHTTPClientTests/TransactionTests.swift @@ -12,27 +12,27 @@ // //===----------------------------------------------------------------------===// -@testable import AsyncHTTPClient import Logging import NIOConcurrencyHelpers import NIOCore import NIOEmbedded +import NIOFoundationCompat import NIOHTTP1 import NIOPosix import XCTest +@testable import AsyncHTTPClient + @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) typealias PreparedRequest = HTTPClientRequest.Prepared @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) final class TransactionTests: XCTestCase { func testCancelAsyncRequest() { - // creating the `XCTestExpectation` off the main thread crashes on Linux with Swift 5.6 - // therefore we create it here as a workaround which works fine - let scheduledRequestCanceled = self.expectation(description: "scheduled request canceled") XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + let scheduledRequestCanceled = loop.makePromise(of: Void.self) + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -43,11 +43,11 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let queuer = MockTaskQueuer { _ in - scheduledRequestCanceled.fulfill() + scheduledRequestCanceled.succeed() } transaction.requestWasQueued(queuer) @@ -62,16 +62,14 @@ final class TransactionTests: XCTestCase { } // self.fulfillment(of:) is not available on Linux - _ = { - self.wait(for: [scheduledRequestCanceled], timeout: 1) - }() + try await scheduledRequestCanceled.futureResult.timeout(after: .seconds(1)).get() } } func testDeadlineExceededWhileQueuedAndExecutorImmediatelyCancelsTask() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -82,7 +80,7 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let queuer = MockTaskQueuer() @@ -91,11 +89,18 @@ final class TransactionTests: XCTestCase { transaction.deadlineExceeded() struct Executor: HTTPRequestExecutor { - func writeRequestBodyPart(_: NIOCore.IOData, request: AsyncHTTPClient.HTTPExecutableRequest, promise: NIOCore.EventLoopPromise?) { + func writeRequestBodyPart( + _: NIOCore.IOData, + request: AsyncHTTPClient.HTTPExecutableRequest, + promise: NIOCore.EventLoopPromise? + ) { XCTFail() } - func finishRequestBodyStream(_ task: AsyncHTTPClient.HTTPExecutableRequest, promise: NIOCore.EventLoopPromise?) { + func finishRequestBodyStream( + _ task: AsyncHTTPClient.HTTPExecutableRequest, + promise: NIOCore.EventLoopPromise? + ) { XCTFail() } @@ -118,8 +123,8 @@ final class TransactionTests: XCTestCase { func testResponseStreamingWorks() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -131,12 +136,12 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -177,8 +182,8 @@ final class TransactionTests: XCTestCase { func testIgnoringResponseBodyWorks() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -190,7 +195,7 @@ final class TransactionTests: XCTestCase { } var tuple: (Transaction, Task)! = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let transaction = tuple.0 @@ -199,9 +204,10 @@ final class TransactionTests: XCTestCase { let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) executor.runRequest(transaction) + await loop.run() let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) XCTAssertFalse(executor.signalledDemandForResponseBody) @@ -225,8 +231,8 @@ final class TransactionTests: XCTestCase { func testWriteBackpressureWorks() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let streamWriter = AsyncSequenceWriter() XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have a demand at this point") @@ -242,26 +248,29 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() for i in 0..<100 { XCTAssertFalse(streamWriter.hasDemand, "Did not expect to have demand yet") transaction.resumeRequestBodyStream() - await streamWriter.demand() // wait's for the stream writer to signal demand + await streamWriter.demand() // wait's for the stream writer to signal demand transaction.pauseRequestBodyStream() let part = ByteBuffer(integer: i) streamWriter.write(part) - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0, part) - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0, part) + } + ) } transaction.resumeRequestBodyStream() @@ -305,12 +314,14 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -351,8 +362,8 @@ final class TransactionTests: XCTestCase { func testSimplePostRequest() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .POST @@ -364,15 +375,18 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") + } + ) XCTAssertNoThrow(try executor.receiveEndOfStream()) let responseHead = HTTPResponseHead(version: .http1_1, status: .ok, headers: ["foo": "bar"]) @@ -388,8 +402,8 @@ final class TransactionTests: XCTestCase { func testPostStreamFails() { XCTAsyncTest { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } let writer = AsyncSequenceWriter() @@ -403,19 +417,22 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) - let executor = MockRequestExecutor(eventLoop: embeddedEventLoop) + let executor = MockRequestExecutor(eventLoop: loop) executor.runRequest(transaction) + await loop.run() executor.resumeRequestBodyStream() await writer.demand() writer.write(.init(string: "Hello world!")) - XCTAssertNoThrow(try executor.receiveRequestBody { - XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") - }) + XCTAssertNoThrow( + try executor.receiveRequestBody { + XCTAssertEqual($0.getString(at: 0, length: $0.readableBytes), "Hello world!") + } + ) XCTAssertFalse(executor.isCancelled) struct WriteError: Error, Equatable {} @@ -430,8 +447,8 @@ final class TransactionTests: XCTestCase { func testResponseStreamFails() { XCTAsyncTest(timeout: 30) { - let embeddedEventLoop = EmbeddedEventLoop() - defer { XCTAssertNoThrow(try embeddedEventLoop.syncShutdownGracefully()) } + let loop = NIOAsyncTestingEventLoop() + defer { XCTAssertNoThrow(try loop.syncShutdownGracefully()) } var request = HTTPClientRequest(url: "https://localhost/") request.method = .GET @@ -443,12 +460,12 @@ final class TransactionTests: XCTestCase { } let (transaction, responseTask) = await Transaction.makeWithResultTask( request: preparedRequest, - preferredEventLoop: embeddedEventLoop + preferredEventLoop: loop ) let executor = MockRequestExecutor( pauseRequestBodyPartStreamAfterASingleWrite: true, - eventLoop: embeddedEventLoop + eventLoop: loop ) transaction.willExecuteRequest(executor) @@ -501,12 +518,14 @@ final class TransactionTests: XCTestCase { let connectionCreator = TestConnectionCreator() let delegate = TestHTTP2ConnectionDelegate() - var maybeHTTP2Connection: HTTP2Connection? - XCTAssertNoThrow(maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( - to: httpBin.port, - delegate: delegate, - on: eventLoop - )) + var maybeHTTP2Connection: HTTP2Connection.SendableView? + XCTAssertNoThrow( + maybeHTTP2Connection = try connectionCreator.createHTTP2Connection( + to: httpBin.port, + delegate: delegate, + on: eventLoop + ) + ) guard let http2Connection = maybeHTTP2Connection else { return XCTFail("Expected to have an HTTP2 connection here.") } @@ -517,7 +536,7 @@ final class TransactionTests: XCTestCase { var request = HTTPClientRequest(url: "https://localhost:\(httpBin.port)/") request.method = .POST request.headers = ["host": "localhost:\(httpBin.port)"] - request.body = .stream(streamWriter, length: .known(800)) + request.body = .stream(streamWriter, length: .known(Int64(800))) var maybePreparedRequest: PreparedRequest? XCTAssertNoThrow(maybePreparedRequest = try PreparedRequest(request)) @@ -567,22 +586,31 @@ final class TransactionTests: XCTestCase { // tasks. Since we want to wait for things to happen in tests, we need to `async let`, which creates // implicit tasks. Therefore we need to wrap our iterator struct. @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -actor SharedIterator where Wrapped.Element: Sendable { - private var wrappedIterator: Wrapped.AsyncIterator - private var nextCallInProgress: Bool = false +final class SharedIterator: Sendable where Wrapped.Element: Sendable { + private struct State: @unchecked Sendable { + var wrappedIterator: Wrapped.AsyncIterator + var nextCallInProgress: Bool = false + } + + private let state: NIOLockedValueBox init(_ sequence: Wrapped) { - self.wrappedIterator = sequence.makeAsyncIterator() + self.state = NIOLockedValueBox(State(wrappedIterator: sequence.makeAsyncIterator())) } func next() async throws -> Wrapped.Element? { - precondition(self.nextCallInProgress == false) - self.nextCallInProgress = true - var iter = self.wrappedIterator + var iter = self.state.withLockedValue { + precondition($0.nextCallInProgress == false) + $0.nextCallInProgress = true + return $0.wrappedIterator + } + defer { - precondition(self.nextCallInProgress == true) - self.nextCallInProgress = false - self.wrappedIterator = iter + self.state.withLockedValue { + precondition($0.nextCallInProgress == true) + $0.nextCallInProgress = false + $0.wrappedIterator = iter + } } return try await iter.next() } @@ -590,7 +618,7 @@ actor SharedIterator where Wrapped.Element: Sendable { /// non fail-able promise that only supports one observer @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -private actor Promise { +private actor Promise { private enum State { case initialised case fulfilled(Value) @@ -629,6 +657,35 @@ private actor Promise { @available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) extension Transaction { + #if compiler(>=6.0) + fileprivate static func makeWithResultTask( + request: sending PreparedRequest, + requestOptions: RequestOptions = .forTests(), + logger: Logger = Logger(label: "test"), + connectionDeadline: NIODeadline = .distantFuture, + preferredEventLoop: EventLoop + ) async -> (Transaction, _Concurrency.Task) { + let transactionPromise = Promise() + let task = Task { + try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + let transaction = Transaction( + request: request, + requestOptions: requestOptions, + logger: logger, + connectionDeadline: connectionDeadline, + preferredEventLoop: preferredEventLoop, + responseContinuation: continuation + ) + Task { + await transactionPromise.fulfil(transaction) + } + } + } + + return (await transactionPromise.value, task) + } + #else fileprivate static func makeWithResultTask( request: PreparedRequest, requestOptions: RequestOptions = .forTests(), @@ -636,9 +693,17 @@ extension Transaction { connectionDeadline: NIODeadline = .distantFuture, preferredEventLoop: EventLoop ) async -> (Transaction, _Concurrency.Task) { + // It isn't sendable ... but on 6.0 and later we use 'sending'. + struct UnsafePrepareRequest: @unchecked Sendable { + var value: PreparedRequest + } + let transactionPromise = Promise() + let unsafe = UnsafePrepareRequest(value: request) let task = Task { - try await withCheckedThrowingContinuation { (continuation: CheckedContinuation) in + try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + let request = unsafe.value let transaction = Transaction( request: request, requestOptions: requestOptions, @@ -655,4 +720,5 @@ extension Transaction { return (await transactionPromise.value, task) } + #endif } diff --git a/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift b/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift index e1d2e4592..6cdcf4f8a 100644 --- a/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift +++ b/Tests/AsyncHTTPClientTests/XCTest+AsyncAwait.swift @@ -11,21 +11,21 @@ // SPDX-License-Identifier: Apache-2.0 // //===----------------------------------------------------------------------===// -/* - * Copyright 2021, gRPC Authors All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// +// Copyright 2021, gRPC Authors All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// import XCTest @@ -53,7 +53,7 @@ extension XCTestCase { try await operation() } catch { XCTFail("Error thrown while executing \(function): \(error)", file: file, line: line) - Thread.callStackSymbols.forEach { print($0) } + for symbol in Thread.callStackSymbols { print(symbol) } } expectation.fulfill() } diff --git a/docker/Dockerfile b/docker/Dockerfile deleted file mode 100644 index 2d1e57def..000000000 --- a/docker/Dockerfile +++ /dev/null @@ -1,34 +0,0 @@ -ARG swift_version=5.7 -ARG ubuntu_version=jammy -ARG base_image=swift:$swift_version-$ubuntu_version -FROM $base_image -# needed to do again after FROM due to docker limitation -ARG swift_version -ARG ubuntu_version - -# set as UTF-8 -RUN apt-get update && apt-get install -y locales locales-all -ENV LC_ALL en_US.UTF-8 -ENV LANG en_US.UTF-8 -ENV LANGUAGE en_US.UTF-8 - -# dependencies -RUN apt-get update && apt-get install -y wget -RUN apt-get update && apt-get install -y lsof dnsutils netcat-openbsd net-tools libz-dev curl jq # used by integration tests - -# ruby and jazzy for docs generation -RUN apt-get update && apt-get install -y ruby ruby-dev libsqlite3-dev build-essential -# jazzy no longer works on xenial as ruby is too old. -RUN if [ "${ubuntu_version}" = "focal" ] ; then echo "gem: --no-document" > ~/.gemrc; fi -RUN if [ "${ubuntu_version}" = "focal" ] ; then gem install jazzy; fi - -# tools -RUN mkdir -p $HOME/.tools -RUN echo 'export PATH="$HOME/.tools:$PATH"' >> $HOME/.profile - -# swiftformat (until part of the toolchain) - -ARG swiftformat_version=0.48.8 -RUN git clone --branch $swiftformat_version --depth 1 https://github.com/nicklockwood/SwiftFormat $HOME/.tools/swift-format -RUN cd $HOME/.tools/swift-format && swift build -c release -RUN ln -s $HOME/.tools/swift-format/.build/release/swiftformat $HOME/.tools/swiftformat diff --git a/docker/docker-compose.2204.510.yaml b/docker/docker-compose.2204.510.yaml deleted file mode 100644 index 8dbf21183..000000000 --- a/docker/docker-compose.2204.510.yaml +++ /dev/null @@ -1,22 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:22.04-5.10 - build: - args: - ubuntu_version: "jammy" - swift_version: "5.10" - - documentation-check: - image: async-http-client:22.04-5.10 - - test: - image: async-http-client:22.04-5.10 - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:22.04-5.10 diff --git a/docker/docker-compose.2204.58.yaml b/docker/docker-compose.2204.58.yaml deleted file mode 100644 index 89b410ae2..000000000 --- a/docker/docker-compose.2204.58.yaml +++ /dev/null @@ -1,22 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:22.04-5.8 - build: - args: - ubuntu_version: "jammy" - swift_version: "5.8" - - documentation-check: - image: async-http-client:22.04-5.8 - - test: - image: async-http-client:22.04-5.8 - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:22.04-5.8 diff --git a/docker/docker-compose.2204.59.yaml b/docker/docker-compose.2204.59.yaml deleted file mode 100644 index b125fff39..000000000 --- a/docker/docker-compose.2204.59.yaml +++ /dev/null @@ -1,22 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:22.04-5.9 - build: - args: - ubuntu_version: "jammy" - swift_version: "5.9" - - documentation-check: - image: async-http-client:22.04-5.9 - - test: - image: async-http-client:22.04-5.9 - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:22.04-5.9 diff --git a/docker/docker-compose.2204.main.yaml b/docker/docker-compose.2204.main.yaml deleted file mode 100644 index 8dfa4c921..000000000 --- a/docker/docker-compose.2204.main.yaml +++ /dev/null @@ -1,21 +0,0 @@ -version: "3" - -services: - - runtime-setup: - image: async-http-client:22.04-main - build: - args: - base_image: "swiftlang/swift:nightly-main-jammy" - - documentation-check: - image: async-http-client:22.04-main - - test: - image: async-http-client:22.04-main - environment: - - IMPORT_CHECK_ARG=--explicit-target-dependency-import-check error - #- SANITIZER_ARG=--sanitize=thread - - shell: - image: async-http-client:22.04-main diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml deleted file mode 100644 index 9ac4a6eea..000000000 --- a/docker/docker-compose.yaml +++ /dev/null @@ -1,45 +0,0 @@ -# this file is not designed to be run directly -# instead, use the docker-compose.. files -# eg docker-compose -f docker/docker-compose.yaml -f docker/docker-compose.1804.50.yaml run test -version: "3" - -services: - - runtime-setup: - image: async-http-client:default - build: - context: . - dockerfile: Dockerfile - - common: &common - image: async-http-client:default - depends_on: [runtime-setup] - volumes: - - ~/.ssh:/root/.ssh - - ..:/code:z - working_dir: /code - cap_drop: - - CAP_NET_RAW - - CAP_NET_BIND_SERVICE - - soundness: - <<: *common - command: /bin/bash -xcl "./scripts/soundness.sh" - - documentation-check: - <<: *common - command: /bin/bash -xcl "./scripts/check-docs.sh" - - test: - <<: *common - command: /bin/bash -xcl "swift test --parallel -Xswiftc -warnings-as-errors --enable-test-discovery $${SANITIZER_ARG-} $${IMPORT_CHECK_ARG-}" - - # util - - shell: - <<: *common - entrypoint: /bin/bash - - docs: - <<: *common - command: /bin/bash -cl "./scripts/generate_docs.sh" diff --git a/scripts/check-docs.sh b/scripts/check-docs.sh deleted file mode 100755 index 61a13a56f..000000000 --- a/scripts/check-docs.sh +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2023 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu - -raw_targets=$(sed -E -n -e 's/^.* - documentation_targets: \[(.*)\].*$/\1/p' .spi.yml) -targets=(${raw_targets//,/ }) - -for target in "${targets[@]}"; do - swift package plugin generate-documentation --target "$target" --warnings-as-errors --analyze --level detailed -done diff --git a/scripts/check_no_api_breakages.sh b/scripts/check_no_api_breakages.sh deleted file mode 100755 index 2d7028617..000000000 --- a/scripts/check_no_api_breakages.sh +++ /dev/null @@ -1,68 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2022 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -##===----------------------------------------------------------------------===## -## -## This source file is part of the SwiftNIO open source project -## -## Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of SwiftNIO project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu - -function usage() { - echo >&2 "Usage: $0 REPO-GITHUB-URL NEW-VERSION OLD-VERSIONS..." - echo >&2 - echo >&2 "This script requires a Swift 5.6+ toolchain." - echo >&2 - echo >&2 "Examples:" - echo >&2 - echo >&2 "Check between main and tag 1.9.0 of async-http-client:" - echo >&2 " $0 https://github.com/swift-server/async-http-client main 1.9.0" - echo >&2 - echo >&2 "Check between HEAD and commit 64cf63d7 using the provided toolchain:" - echo >&2 " xcrun --toolchain org.swift.5120190702a $0 ../some-local-repo HEAD 64cf63d7" -} - -if [[ $# -lt 3 ]]; then - usage - exit 1 -fi - -tmpdir=$(mktemp -d /tmp/.check-api_XXXXXX) -repo_url=$1 -new_tag=$2 -shift 2 - -repodir="$tmpdir/repo" -git clone "$repo_url" "$repodir" -git -C "$repodir" fetch -q origin '+refs/pull/*:refs/remotes/origin/pr/*' -cd "$repodir" -git checkout -q "$new_tag" - -for old_tag in "$@"; do - echo "Checking public API breakages from $old_tag to $new_tag" - - swift package diagnose-api-breaking-changes "$old_tag" -done - -echo done diff --git a/scripts/generate_contributors_list.sh b/scripts/generate_contributors_list.sh deleted file mode 100755 index 00c162638..000000000 --- a/scripts/generate_contributors_list.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu -here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -contributors=$( cd "$here"/.. && git shortlog -es | cut -f2 | sed 's/^/- /' ) - -cat > "$here/../CONTRIBUTORS.txt" <<- EOF - For the purpose of tracking copyright, this is the list of individuals and - organizations who have contributed source code to the AsyncHTTPClient. - - For employees of an organization/company where the copyright of work done - by employees of that company is held by the company itself, only the company - needs to be listed here. - - ## COPYRIGHT HOLDERS - - - Apple Inc. (all contributors with '@apple.com') - - ### Contributors - - $contributors - - **Updating this list** - - Please do not edit this file manually. It is generated using \`./scripts/generate_contributors_list.sh\`. If a name is misspelled or appearing multiple times: add an entry in \`./.mailmap\` -EOF diff --git a/scripts/generate_docs.sh b/scripts/generate_docs.sh deleted file mode 100755 index 82da814d3..000000000 --- a/scripts/generate_docs.sh +++ /dev/null @@ -1,114 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -e - -my_path="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" -root_path="$my_path/.." -version=$(git describe --abbrev=0 --tags || echo "main") -modules=(AsyncHTTPClient) - -if [[ "$(uname -s)" == "Linux" ]]; then - # build code if required - if [[ ! -d "$root_path/.build/x86_64-unknown-linux" ]]; then - swift build - fi - # setup source-kitten if required - mkdir -p "$root_path/.build/sourcekitten" - source_kitten_source_path="$root_path/.build/sourcekitten/source" - if [[ ! -d "$source_kitten_source_path" ]]; then - git clone https://github.com/jpsim/SourceKitten.git "$source_kitten_source_path" - fi - source_kitten_path="$source_kitten_source_path/.build/debug" - if [[ ! -d "$source_kitten_path" ]]; then - rm -rf "$source_kitten_source_path/.swift-version" - cd "$source_kitten_source_path" && swift build && cd "$root_path" - fi - # generate - for module in "${modules[@]}"; do - if [[ ! -f "$root_path/.build/sourcekitten/$module.json" ]]; then - "$source_kitten_path/sourcekitten" doc --spm --module-name $module > "$root_path/.build/sourcekitten/$module.json" - fi - done -fi - -[[ -d docs/$version ]] || mkdir -p docs/$version -[[ -d async-http-client.xcodeproj ]] || swift package generate-xcodeproj - -# run jazzy -if ! command -v jazzy > /dev/null; then - gem install jazzy --no-ri --no-rdoc -fi - -jazzy_dir="$root_path/.build/jazzy" -rm -rf "$jazzy_dir" -mkdir -p "$jazzy_dir" - -module_switcher="$jazzy_dir/README.md" -jazzy_args=(--clean - --author 'AsyncHTTPClient team' - --readme "$module_switcher" - --author_url https://github.com/swift-server/async-http-client - --github_url https://github.com/swift-server/async-http-client - --github-file-prefix "https://github.com/swift-server/async-http-client/tree/$version" - --theme fullwidth - --xcodebuild-arguments -scheme,async-http-client-Package) -cat > "$module_switcher" <<"EOF" -# AsyncHTTPClient Docs - -AsyncHTTPClient is a Swift HTTP Client package. - -To get started with AsyncHTTPClient, [`import AsyncHTTPClient`](../AsyncHTTPClient/index.html). The -most important type is [`HTTPClient`](https://swift-server.github.io/async-http-client/docs/current/AsyncHTTPClient/Classes/HTTPClient.html) -which you can use to emit log messages. - -EOF - -tmp=`mktemp -d` -for module in "${modules[@]}"; do - args=("${jazzy_args[@]}" --output "$jazzy_dir/docs/$version/$module" --docset-path "$jazzy_dir/docset/$version/$module" - --module "$module" --module-version $version - --root-url "https://swift-server.github.io/async-http-client/docs/$version/$module/") - if [[ -f "$root_path/.build/sourcekitten/$module.json" ]]; then - args+=(--sourcekitten-sourcefile "$root_path/.build/sourcekitten/$module.json") - fi - jazzy "${args[@]}" -done - -# push to github pages -if [[ $PUSH == true ]]; then - BRANCH_NAME=$(git rev-parse --abbrev-ref HEAD) - GIT_AUTHOR=$(git --no-pager show -s --format='%an <%ae>' HEAD) - git fetch origin +gh-pages:gh-pages - git checkout gh-pages - rm -rf "docs/$version" - rm -rf "docs/current" - cp -r "$jazzy_dir/docs/$version" docs/ - cp -r "docs/$version" docs/current - git add --all docs - echo '' > index.html - git add index.html - touch .nojekyll - git add .nojekyll - changes=$(git diff-index --name-only HEAD) - if [[ -n "$changes" ]]; then - echo -e "changes detected\n$changes" - git commit --author="$GIT_AUTHOR" -m "publish $version docs" - git push origin gh-pages - else - echo "no changes detected" - fi - git checkout -f $BRANCH_NAME -fi diff --git a/scripts/soundness.sh b/scripts/soundness.sh deleted file mode 100755 index 216eab206..000000000 --- a/scripts/soundness.sh +++ /dev/null @@ -1,152 +0,0 @@ -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) 2018-2022 Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## - -set -eu -here="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -function replace_acceptable_years() { - # this needs to replace all acceptable forms with 'YEARS' - sed -e 's/20[12][0-9]-20[12][0-9]/YEARS/' -e 's/20[12][0-9]/YEARS/' -} - -printf "=> Checking for unacceptable language... " -# This greps for unacceptable terminology. The square bracket[s] are so that -# "git grep" doesn't find the lines that greps :). -unacceptable_terms=( - -e blacklis[t] - -e whitelis[t] - -e slav[e] - -e sanit[y] -) -if git grep --color=never -i "${unacceptable_terms[@]}" > /dev/null; then - printf "\033[0;31mUnacceptable language found.\033[0m\n" - git grep -i "${unacceptable_terms[@]}" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n" - -printf "=> Checking format... " -FIRST_OUT="$(git status --porcelain)" -swiftformat . > /dev/null 2>&1 -SECOND_OUT="$(git status --porcelain)" -if [[ "$FIRST_OUT" != "$SECOND_OUT" ]]; then - printf "\033[0;31mformatting issues!\033[0m\n" - git --no-pager diff - exit 1 -else - printf "\033[0;32mokay.\033[0m\n" -fi - -printf "=> Checking license headers\n" -tmp=$(mktemp /tmp/.async-http-client-soundness_XXXXXX) - -for language in swift-or-c bash dtrace; do - printf " * $language... " - declare -a matching_files - declare -a exceptions - expections=( ) - matching_files=( -name '*' ) - case "$language" in - swift-or-c) - exceptions=( -name c_nio_http_parser.c -o -name c_nio_http_parser.h -o -name cpp_magic.h -o -name Package.swift -o -name CNIOSHA1.h -o -name c_nio_sha1.c -o -name ifaddrs-android.c -o -name ifaddrs-android.h -o -name 'Package@swift*.swift' ) - matching_files=( -name '*.swift' -o -name '*.c' -o -name '*.h' ) - cat > "$tmp" <<"EOF" -//===----------------------------------------------------------------------===// -// -// This source file is part of the AsyncHTTPClient open source project -// -// Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// -EOF - ;; - bash) - matching_files=( -name '*.sh' ) - cat > "$tmp" <<"EOF" -#!/bin/bash -##===----------------------------------------------------------------------===## -## -## This source file is part of the AsyncHTTPClient open source project -## -## Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors -## Licensed under Apache License v2.0 -## -## See LICENSE.txt for license information -## See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors -## -## SPDX-License-Identifier: Apache-2.0 -## -##===----------------------------------------------------------------------===## -EOF - ;; - dtrace) - matching_files=( -name '*.d' ) - cat > "$tmp" <<"EOF" -#!/usr/sbin/dtrace -q -s -/*===----------------------------------------------------------------------===* - * - * This source file is part of the AsyncHTTPClient open source project - * - * Copyright (c) YEARS Apple Inc. and the AsyncHTTPClient project authors - * Licensed under Apache License v2.0 - * - * See LICENSE.txt for license information - * See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors - * - * SPDX-License-Identifier: Apache-2.0 - * - *===----------------------------------------------------------------------===*/ -EOF - ;; - *) - echo >&2 "ERROR: unknown language '$language'" - ;; - esac - - expected_lines=$(cat "$tmp" | wc -l) - expected_sha=$(cat "$tmp" | shasum) - - ( - cd "$here/.." - find . \ - \( \! -path './.build/*' -a \ - \( "${matching_files[@]}" \) -a \ - \( \! \( "${exceptions[@]}" \) \) \) | while read line; do - if [[ "$(cat "$line" | replace_acceptable_years | head -n $expected_lines | shasum)" != "$expected_sha" ]]; then - printf "\033[0;31mmissing headers in file '$line'!\033[0m\n" - diff -u <(cat "$line" | replace_acceptable_years | head -n $expected_lines) "$tmp" - exit 1 - fi - done - printf "\033[0;32mokay.\033[0m\n" - ) -done - -rm "$tmp" - -# This checks for the umbrella NIO module. -printf "=> Checking for imports of umbrella NIO module... " -if git grep --color=never -i "^[ \t]*import \+NIO[ \t]*$" > /dev/null; then - printf "\033[0;31mUmbrella imports found.\033[0m\n" - git grep -i "^[ \t]*import \+NIO[ \t]*$" - exit 1 -fi -printf "\033[0;32mokay.\033[0m\n"