diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 000000000..a32c57bb2 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +# Handle line endings automatically for files detected as text +* text=auto + +# These files are text and should be normalized +*.java text diff=java \ No newline at end of file diff --git a/.github/workflows/gradle-all.yml b/.github/workflows/gradle-all.yml new file mode 100644 index 000000000..abbd14106 --- /dev/null +++ b/.github/workflows/gradle-all.yml @@ -0,0 +1,152 @@ +name: Branches Java CI + +on: + # Trigger the workflow on push + # but only for the non master/1.0.x branches + push: + branches-ignore: + - 1.1.x + - master + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Publish Packages to Artifactory + if: ${{ matrix.jdk == '1.8' }} + run: | + githubRef="${githubRef#refs/heads/}" + githubRef="${githubRef////-}" + ./gradlew -PversionSuffix="-${githubRef}-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToGitHubPackagesRepository --no-daemon --stacktrace + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + githubRef: ${{ github.ref }} + buildNumber: ${{ github.run_number }} \ No newline at end of file diff --git a/.github/workflows/gradle-main.yml b/.github/workflows/gradle-main.yml new file mode 100644 index 000000000..33bca8e72 --- /dev/null +++ b/.github/workflows/gradle-main.yml @@ -0,0 +1,161 @@ +name: Main Branches Java CI + +on: + # Trigger the workflow on push + # but only for the master/1.1.x branch + push: + branches: + - master + - 1.1.x + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon + + publish: + needs: [ build, coretest, othertest, jcstress ] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Publish Packages to Artifactory + if: ${{ matrix.jdk == '1.8' }} + run: ./gradlew -PversionSuffix="-SNAPSHOT" -PbuildNumber="${buildNumber}" publishMavenPublicationToSonatypeRepository --no-daemon --stacktrace + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} + - name: Aggregate test reports with ciMate + if: always() + continue-on-error: true + env: + CIMATE_PROJECT_ID: m84qx17y + run: | + wget -q https://get.cimate.io/release/linux/cimate + chmod +x cimate + ./cimate "**/TEST-*.xml" \ No newline at end of file diff --git a/.github/workflows/gradle-pr.yml b/.github/workflows/gradle-pr.yml new file mode 100644 index 000000000..cecca085f --- /dev/null +++ b/.github/workflows/gradle-pr.yml @@ -0,0 +1,111 @@ +name: Pull Request Java CI + +on: [pull_request] + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test --no-daemon + + coretest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew rsocket-core:test --no-daemon + + othertest: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew test -x :rsocket-core:test --no-daemon + + jcstress: + needs: [build] + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + jdk: [ 1.8, 11, 17 ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK ${{ matrix.jdk }} + uses: actions/setup-java@v1 + with: + java-version: ${{ matrix.jdk }} + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew jcstress --no-daemon \ No newline at end of file diff --git a/.github/workflows/gradle-release.yml b/.github/workflows/gradle-release.yml new file mode 100644 index 000000000..922eb0e3e --- /dev/null +++ b/.github/workflows/gradle-release.yml @@ -0,0 +1,44 @@ +name: Release Java CI + +on: + # Trigger the workflow on push + push: + # Sequence of patterns matched against refs/tags + tags: + - '*' # Push events to matching *, i.e. 1.0, 20.15.10 + +jobs: + publish: + + runs-on: ${{ matrix.os }} + + strategy: + matrix: + os: [ ubuntu-latest ] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + - name: Set up JDK 1.8 + uses: actions/setup-java@v1 + with: + java-version: 1.8 + - name: Cache Gradle packages + uses: actions/cache@v1 + with: + path: ~/.gradle/caches + key: ${{ runner.os }}-gradle-${{ hashFiles('**/*.gradle') }} + restore-keys: ${{ runner.os }}-gradle + - name: Grant execute permission for gradlew + run: chmod +x gradlew + - name: Build with Gradle + run: ./gradlew clean build -x test + - name: Publish Packages to Sonotype + run: ./gradlew -Pversion="${githubRef#refs/tags/}" -PbuildNumber="${buildNumber}" sign publishMavenPublicationToSonatypeRepository + env: + githubRef: ${{ github.ref }} + buildNumber: ${{ github.run_number }} + ORG_GRADLE_PROJECT_signingKey: ${{secrets.signingKey}} + ORG_GRADLE_PROJECT_signingPassword: ${{secrets.signingPassword}} + ORG_GRADLE_PROJECT_sonatypeUsername: ${{secrets.sonatypeUsername}} + ORG_GRADLE_PROJECT_sonatypePassword: ${{secrets.sonatypePassword}} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 3be3ba898..92865ccca 100644 --- a/.gitignore +++ b/.gitignore @@ -65,4 +65,14 @@ atlassian-ide-plugin.xml # NetBeans specific files/directories .nbattrs -/bin +**/bin/* + +#.gitignore in subdirectory +.gitignore + +### infer ### +# infer- http://fbinfer.com/ +infer-out +*/infer-out +.inferConfig +*/.inferConfig diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 0039232f7..000000000 --- a/.travis.yml +++ /dev/null @@ -1,26 +0,0 @@ -language: java -jdk: -- oraclejdk8 - -# force upgrade Java8 as per https://github.com/travis-ci/travis-ci/issues/4042 (fixes compilation issue) -addons: - apt: - packages: - - oracle-java8-installer - -sudo: false -# as per http://blog.travis-ci.com/2014-12-17-faster-builds-with-container-based-infrastructure/ - -# script for build and release via Travis to Bintray -script: gradle/buildViaTravis.sh - -# cache between builds -cache: - directories: - - $HOME/.m2 - - $HOME/.gradle - -env: - global: - - secure: cXHzd2WHqmdmJEyEKlELt8Rp9qCvhTRXTEHpQz0sKt55KorI8vO33sSOBs8uBqknWgGgOzHsB7cw0dJRxCmW+BRy90ELtdg/dVLzU8D8BrI6/DHzd/Bhyt9wx2eVdLmDV7lQ113AqJ7lphbH+U8ceTBlbNYDPKcIjFhsPO0WcPxQYed45na8XRK0UcAOpVmmNlTE6fHy5acQblNO84SN6uevCFqWAZJY7rc6xGrzFzca+ul5kR8xIzdE5jKs2Iw0MDeWi8cshkhj9c0FDtfsNIB1F+NafDtEdqjt6kMqYAUUiTAM2QdNoffzgmWEbVOj3uvthlm+S11XaU3Cn2uC7CiZTn2ebuoqCuV5Ge6KQI0ysEQVUfLhIF7iJG6dJvoyYy8ta8LEcjcsYAdF34BVddoUJkp+eJuhlto2aTZsDdXpmnwRM1PPDRoyrLjRcKiWYPR2tO2RG9sb0nRAGEpHTDd5ju2Ta4zpvgpWGUiKprs5R+YY7TEg16VSTYMmCJj5C9ap2lYIH4EoxsQpuxYig9AV1sOUJujLSa4TXqlcOmSM0IkHJ/i0VE8TZg4nV4XowyH6nKZ63InF4pUDcG13BpJQyTFKbK2D0lFn8MzpWvIV2oOUxNkOaOBg9cGhAnv9Sfw/Iv1UVaUgCNQd2M0R0rwfJoPCg2mmWVxsvh3cW4M= - - secure: UKZHoS/uw6SuAI9n0lCHWEc74H9+STpdvMmLd+nANjWkMFfo0bOUbm1SsV6PU6d2r8C5k4dEsW90J4diunR856R8vO+DpJPwUNJDuLm2Kiv7zhLJrXqpRTw3E3ijdFA84xocTN1CxZakW+ZP2wnb83jI3p99rgotc0i1wz9n1onaZrhZK5c3Rod2cdRig0wkeKK9NhwupXbXkpPtRNFRCOPgKvjPiEeW5YRZ/YxOs+OL9Sy6764b46EiWP/DFPGOTkJwz2mxLRT8sBx6rjeyf6v41NQPW1rlNwIDKcpnQl1n49k5SgARZvhFlakRdLyzljj1L0/VLk7xNDEQx3LYxl2mSl7AQlA8RYkxqirMRnIHHXrA7hhPuCYxp2nlpciwuh69vAOfliL3JeAsEgj0PKiQp7HQyBPQOvfmiGH2oIo+dkXvQwmLZTDnj9vNzZIS+rADbZoLzKftZKAUIWCze5zQ6mCkwKiuVYYWl2aPoy2XxRkA51t6sEHA0/iYrqaOX76WHGH0JhoAGWEIBNNP/rRnO38Rm96pm6SHrzLa1VxVFT6dRGljFTxvCsgsCfx/rRL+a1E0j89nLAmOGkDpyhUaKWqVQJWk3H1AeQ3cWGXfvUhDyaSTxcKs6AuQ2E5TtNgkbx0Xjq8NDjuiP57WDFYMXGvIqkgSzKG3A0DSMHI= diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 000000000..ef7dd9dda --- /dev/null +++ b/AUTHORS @@ -0,0 +1,21 @@ +benjchristensen = Ben Christensen +gregwhitaker = Greg Whitaker +junaidkhalid = Junaid Khalid +kojilin = Kang-Sze Lin +krisskross = Kristoffer Sjogren +ktoso = Konrad Malawski +lehecka = Ondrej Lehecka +lexs = Alexander Blom +mostroverkhov = Maksym Ostroverkhov +nebhale = Ben Hale +NiteshKant = Nitesh Kant +qweek = Alex Novoselov +rdegnan = Ryland Degnan +robertroeser = Robert Roeser +rstoyanchev = Rossen Stoyanchev +simonbasle = Simon Baslé +somasun = somasun +stevegury = Steve Gury +tmontgomery = Todd L. Montgomery +yschimke = Yuri Schimke +OlegDokuka = Oleh Dokuka diff --git a/CHANGES.md b/CHANGES.md index 265b6c45c..b8388885c 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,3 @@ -# ReactiveSocket Releases # +# RSocket Releases # No releases yet. \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1025ac4b4..56a5a7b69 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,4 +1,4 @@ -# Contributing to ReactiveSocket +# Contributing to RSocket If you would like to contribute code you can do so through GitHub by forking the repository and sending a pull request (on a branch other than `master`, `0.x`, `1.x`, or `gh-pages`). @@ -6,22 +6,22 @@ When submitting code, please make every effort to follow existing conventions an ## License -By contributing your code, you agree to license your contribution under the terms of the APLv2: https://github.com/ReactiveSocket/reactivesocket-java/blob/master/LICENSE +By contributing your code, you agree to license your contribution under the terms of the APLv2: https://github.com/rsocket/rsocket-java/blob/1.0.x/LICENSE All files are released with the Apache 2.0 license. If you are adding a new file it should have a header like this: -``` -/** - * Copyright 2015 Netflix, Inc. - * +```java +/* + * Copyright 2015-2018 the original author or authors. + * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * + * + * http://www.apache.org/licenses/LICENSE-2.0 + * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/NOTICE b/NOTICE new file mode 100644 index 000000000..ea8e324f1 --- /dev/null +++ b/NOTICE @@ -0,0 +1,15 @@ +RSocket Java + +Copyright 2015-2018 the original author or authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/README.md b/README.md index d05fce64d..7ed3244b8 100644 --- a/README.md +++ b/README.md @@ -1,63 +1,145 @@ -# ReactiveSocket +# RSocket -ReactiveSocket is network protocol with client and server that uses [Reactive Streams](http://reactive-streams.org) (and optimistically [Reactive Streams IO](http://reactive-streams.io) as it gets defined). +[![Join the chat at https://gitter.im/RSocket/RSocket-Java](https://badges.gitter.im/rsocket/rsocket-java.svg)](https://gitter.im/rsocket/rsocket-java) -It enables the following interaction models via async message passing over a single network connection: +RSocket is a binary protocol for use on byte stream transports such as TCP, WebSockets, and Aeron. + +It enables the following symmetric interaction models via async message passing over a single connection: - request/response (stream of 1) - request/stream (finite stream of many) - fire-and-forget (no response) - event subscription (infinite stream of many) -This is the core project for Java that implements the protocol and exposes Reactive Stream APIs. Typically most use will come via another library that uses this one. - -For example: +Learn more at http://rsocket.io -- ReactiveSocket over WebSockets using Netty [reactivesocket-websockets-netty](https://github.com/ReactiveSocket/reactivesocket-java-impl) -- ReactiveSocket over TCP using Netty [reactivesocket-tcp-netty](https://github.com/ReactiveSocket/reactivesocket-java-impl) -- ReactiveSocket over Aeron using Aeron Java [reactivesocket-aeron-java](https://github.com/ReactiveSocket/reactivesocket-java-impl) +## Build and Binaries -ReactiveSocket is for communicating across network boundaries thus it is intended to be polyglot. Common libraries include: +[![Build Status](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml/badge.svg?branch=master)](https://github.com/rsocket/rsocket-java/actions/workflows/gradle-main.yml) -- ReactiveSocket over WebSockets using Javascript (for Node.js and browsers) [reactivesocket-websockets-javascript](https://github.com/ReactiveSocket/reactivesocket-websockets-javascript) +⚠️ The `master` branch is now dedicated to development of the `1.2.x` line. -Others can be found in the [ReactiveSocket Github](https://github.com/ReactiveSocket) project. +Releases and milestones are available via Maven Central. -## Build and Binaries +Example: - +```groovy +repositories { + mavenCentral() + maven { url 'https://repo.spring.io/milestone' } // Reactor milestones (if needed) +} +dependencies { + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' +} +``` -Snapshots are available via JFrog. +Snapshots are available via [oss.jfrog.org](oss.jfrog.org) (OJO). Example: ```groovy repositories { - maven { url 'https://oss.jfrog.org/libs-snapshot' } + maven { url 'https://maven.pkg.github.com/rsocket/rsocket-java' } + maven { url 'https://repo.spring.io/snapshot' } // Reactor snapshots (if needed) } - dependencies { - compile 'io.reactivesocket:reactivesocket:0.0.1-SNAPSHOT' + implementation 'io.rsocket:rsocket-core:1.2.0-SNAPSHOT' + implementation 'io.rsocket:rsocket-transport-netty:1.2.0-SNAPSHOT' } ``` -No releases to Maven Central or JCenter have occurred yet. +## Development + +Install the google-java-format in Intellij, from Plugins preferences. +Enable under Preferences -> Other Settings -> google-java-format Settings + +Format automatically with + +``` +$./gradlew goJF +``` + +## Debugging +Frames can be printed out to help debugging. Set the logger `io.rsocket.FrameLogger` to debug to print the frames. + +## Requirements + +- Java 8 - heavy dependence on Java 8 functional APIs and java.time, also on Reactor +- Android O - https://github.com/rsocket/rsocket-demo-android-java8 + +## Trivial Client + +```java +package io.rsocket.transport.netty; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.DefaultPayload; +import reactor.core.publisher.Flux; + +import java.net.URI; + +public class ExampleClient { + public static void main(String[] args) { + WebsocketClientTransport ws = WebsocketClientTransport.create(URI.create("ws://rsocket-demo.herokuapp.com/ws")); + RSocket clientRSocket = RSocketConnector.connectWith(ws).block(); + + try { + Flux s = clientRSocket.requestStream(DefaultPayload.create("peace")); + + s.take(10).doOnNext(p -> System.out.println(p.getDataUtf8())).blockLast(); + } finally { + clientRSocket.dispose(); + } + } +} +``` + +## Zero Copy +By default to make RSocket easier to use it copies the incoming Payload. Copying the payload comes at cost to performance +and latency. If you want to use zero copy you must disable this. To disable copying you must include a `payloadDecoder` +argument in your `RSocketFactory`. This will let you manage the Payload without copying the data from the underlying +transport. You must free the Payload when you are done with them +or you will get a memory leak. Used correctly this will reduce latency and increase performance. + +### Example Server setup +```java +RSocketServer.create(new PingHandler()) + // Enable Zero Copy + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(7878)) + .block() + .onClose() + .block(); +``` + +### Example Client setup +```java +RSocket clientRSocket = + RSocketConnector.create() + // Enable Zero Copy + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(TcpClientTransport.create(7878)) + .block(); +``` ## Bugs and Feedback -For bugs, questions and discussions please use the [Github Issues](https://github.com/ReactiveSocket/reactivesocket-java/issues). +For bugs, questions and discussions please use the [Github Issues](https://github.com/RSocket/reactivesocket-java/issues). - ## LICENSE -Copyright 2015 Netflix, Inc. +Copyright 2015-2020 the original author or authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 000000000..656e2de4b --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,47 @@ +## Usage of JMH tasks + +Only execute specific benchmark(s) (wildcards are added before and after): +``` +../gradlew jmh --include="(BenchmarkPrimary|OtherBench)" +``` +If you want to specify the wildcards yourself, you can pass the full regexp: +``` +../gradlew jmh --fullInclude=.*MyBenchmark.* +``` + +Specify extra profilers: +``` +../gradlew jmh --profilers="gc,stack" +``` + +Prominent profilers (for full list call `jmhProfilers` task): +- comp - JitCompilations, tune your iterations +- stack - which methods used most time +- gc - print garbage collection defaultWeightedStats +- hs_thr - thread usage + +Change report format from JSON to one of [CSV, JSON, NONE, SCSV, TEXT]: +``` +./gradlew jmh --format=csv +``` + +Specify JVM arguments: +``` +../gradlew jmh --jvmArgs="-Dtest.cluster=local" +``` + +Run in verification mode (execute benchmarks with minimum of fork/warmup-/benchmark-iterations): +``` +../gradlew jmh --verify=true +``` + +## Comparing with the baseline +If you wish you run two sets of benchmarks, one for the current change and another one for the "baseline", +there is an additional task `jmhBaseline` that will use the latest release: +``` +../gradlew jmh jmhBaseline --include=MyBenchmark +``` + +## Resources +- http://tutorials.jenkov.com/java-performance/jmh.html (Introduction) +- http://hg.openjdk.java.net/code-tools/jmh/file/tip/jmh-samples/src/main/java/org/openjdk/jmh/samples/ (Samples) diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle new file mode 100644 index 000000000..74e571d1f --- /dev/null +++ b/benchmarks/build.gradle @@ -0,0 +1,170 @@ +apply plugin: 'java' +apply plugin: 'idea' + +configurations { + current + baseline { + resolutionStrategy.cacheChangingModulesFor 0, 'seconds' + } +} + +dependencies { + // Use the baseline to avoid using new APIs in the benchmarks + compileOnly "io.rsocket:rsocket-core:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-local:${perfBaselineVersion}" + compileOnly "io.rsocket:rsocket-transport-netty:${perfBaselineVersion}" + + implementation "org.openjdk.jmh:jmh-core:1.35" + annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:1.35" + + current project(':rsocket-core') + current project(':rsocket-transport-local') + current project(':rsocket-transport-netty') + baseline "io.rsocket:rsocket-core:${perfBaselineVersion}", { + changing = true + } + baseline "io.rsocket:rsocket-transport-local:${perfBaselineVersion}", { + changing = true + } +} + +task jmhProfilers(type: JavaExec, description:'Lists the available profilers for the jmh task', group: 'Development') { + classpath = sourceSets.main.runtimeClasspath + main = 'org.openjdk.jmh.Main' + args '-lprof' +} + +task jmh(type: JmhExecTask, description: 'Executing JMH benchmarks') { + main = 'org.openjdk.jmh.Main' + classpath = sourceSets.main.runtimeClasspath + configurations.current +} + +task jmhBaseline(type: JmhExecTask, description: 'Executing JMH baseline benchmarks') { + main = 'org.openjdk.jmh.Main' + classpath = sourceSets.main.runtimeClasspath + configurations.baseline +} + +clean { + delete "${projectDir}/src/main/generated" +} + +class JmhExecTask extends JavaExec { + + private String include; + private String fullInclude; + private String exclude; + private String format = "json"; + private String profilers; + private String jmhJvmArgs; + private String verify; + + public JmhExecTask() { + super(); + } + + public String getInclude() { + return include; + } + + @Option(option = "include", description="configure bench inclusion using substring") + public void setInclude(String include) { + this.include = include; + } + + public String getFullInclude() { + return fullInclude; + } + + @Option(option = "fullInclude", description = "explicitly configure bench inclusion using full JMH style regexp") + public void setFullInclude(String fullInclude) { + this.fullInclude = fullInclude; + } + + public String getExclude() { + return exclude; + } + + @Option(option = "exclude", description = "explicitly configure bench exclusion using full JMH style regexp") + public void setExclude(String exclude) { + this.exclude = exclude; + } + + public String getFormat() { + return format; + } + + @Option(option = "format", description = "configure report format") + public void setFormat(String format) { + this.format = format; + } + + public String getProfilers() { + return profilers; + } + + @Option(option = "profilers", description = "configure jmh profiler(s) to use, comma separated") + public void setProfilers(String profilers) { + this.profilers = profilers; + } + + public String getJmhJvmArgs() { + return jmhJvmArgs; + } + + @Option(option = "jvmArgs", description = "configure additional JMH JVM arguments, comma separated") + public void setJmhJvmArgs(String jvmArgs) { + this.jmhJvmArgs = jvmArgs; + } + + public String getVerify() { + return verify; + } + + @Option(option = "verify", description = "run in verify mode") + public void setVerify(String verify) { + this.verify = verify; + } + + @TaskAction + public void exec() { + File resultFile = getProject().file("build/reports/" + getName() + "/result." + format); + + if (include != null) { + args(".*" + include + ".*"); + } + else if (fullInclude != null) { + args(fullInclude); + } + + if(exclude != null) { + args("-e", exclude); + } + if(verify != null) { // execute benchmarks with the minimum amount of execution (only to check if they are working) + System.out.println("Running in verify mode"); + args("-f", 1); + args("-wi", 1); + args("-i", 1); + } + args("-foe", "true"); //fail-on-error + args("-v", "NORMAL"); //verbosity [SILENT, NORMAL, EXTRA] + if(profilers != null) { + for (String prof : profilers.split(",")) { + args("-prof", prof); + } + } + args("-jvmArgsPrepend", "-Xmx3072m"); + args("-jvmArgsPrepend", "-Xms3072m"); + if(jmhJvmArgs != null) { + for(String jvmArg : jmhJvmArgs.split(" ")) { + args("-jvmArgsPrepend", jvmArg); + } + } + args("-rf", format); + args("-rff", resultFile); + + System.out.println("\nExecuting JMH with: " + getArgs() + "\n"); + resultFile.getParentFile().mkdirs(); + + super.exec(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java new file mode 100644 index 000000000..2e6fa6acc --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/MaxPerfSubscriber.java @@ -0,0 +1,37 @@ +package io.rsocket; + +import java.util.concurrent.CountDownLatch; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +public class MaxPerfSubscriber extends CountDownLatch implements CoreSubscriber { + + final Blackhole blackhole; + + public MaxPerfSubscriber(Blackhole blackhole) { + super(1); + this.blackhole = blackhole; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(T payload) { + blackhole.consume(payload); + } + + @Override + public void onError(Throwable t) { + blackhole.consume(t); + countDown(); + } + + @Override + public void onComplete() { + countDown(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java new file mode 100644 index 000000000..7a7a1fdd6 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsMaxPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsMaxPerfSubscriber extends MaxPerfSubscriber { + + public PayloadsMaxPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java new file mode 100644 index 000000000..efc116958 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PayloadsPerfSubscriber.java @@ -0,0 +1,16 @@ +package io.rsocket; + +import org.openjdk.jmh.infra.Blackhole; + +public class PayloadsPerfSubscriber extends PerfSubscriber { + + public PayloadsPerfSubscriber(Blackhole blackhole) { + super(blackhole); + } + + @Override + public void onNext(Payload payload) { + payload.release(); + super.onNext(payload); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java new file mode 100644 index 000000000..92577d95c --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/PerfSubscriber.java @@ -0,0 +1,41 @@ +package io.rsocket; + +import java.util.concurrent.CountDownLatch; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +public class PerfSubscriber extends CountDownLatch implements CoreSubscriber { + + final Blackhole blackhole; + + Subscription s; + + public PerfSubscriber(Blackhole blackhole) { + super(1); + this.blackhole = blackhole; + } + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + s.request(1); + } + + @Override + public void onNext(T payload) { + blackhole.consume(payload); + s.request(1); + } + + @Override + public void onError(Throwable t) { + blackhole.consume(t); + countDown(); + } + + @Override + public void onComplete() { + countDown(); + } +} diff --git a/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java new file mode 100644 index 000000000..4437400c4 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/core/RSocketPerf.java @@ -0,0 +1,226 @@ +package io.rsocket.core; + +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.PayloadsMaxPerfSubscriber; +import io.rsocket.PayloadsPerfSubscriber; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.lang.reflect.Field; +import java.util.Queue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.LockSupport; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Level; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.TearDown; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; +import org.reactivestreams.Publisher; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@BenchmarkMode({Mode.Throughput, Mode.SampleTime}) +@Fork(value = 2) +@Warmup(iterations = 10) +@Measurement(iterations = 10, time = 10) +@State(Scope.Benchmark) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +public class RSocketPerf { + + @Param({"tcp", "websocket", "local"}) + String transportType; + + @Param({"0", "64", "1024", "131072", "1048576", "15728640"}) + String payloadSize; + + Payload payload; + Mono payloadMono; + Flux payloadsFlux; + + RSocket client; + Closeable server; + Queue clientsQueue; + + @TearDown + public void tearDown() { + client.dispose(); + server.dispose(); + payload.release(); + } + + @TearDown(Level.Iteration) + public void awaitToBeConsumed() { + while (!clientsQueue.isEmpty()) { + LockSupport.parkNanos(1000); + } + } + + @Setup + public void setUp() throws NoSuchFieldException, IllegalAccessException, ClassNotFoundException { + ClientTransport clientTransport; + ServerTransport serverTransport; + switch (transportType) { + case "tcp": + clientTransport = TcpClientTransport.create(8081); + serverTransport = TcpServerTransport.create(8081); + break; + case "websocket": + clientTransport = WebsocketClientTransport.create(8081); + serverTransport = WebsocketServerTransport.create(8081); + break; + case "local": + default: + clientTransport = LocalClientTransport.create("server"); + serverTransport = LocalServerTransport.create("server"); + break; + } + Payload payload; + int payloadSize = Integer.parseInt(this.payloadSize); + if (payloadSize == 0) { + payload = EmptyPayload.INSTANCE; + } else { + byte[] randomMetadata = new byte[payloadSize / 2]; + byte[] randomData = new byte[payloadSize / 2]; + ThreadLocalRandom.current().nextBytes(randomData); + ThreadLocalRandom.current().nextBytes(randomMetadata); + + payload = ByteBufPayload.create(randomData, randomMetadata); + } + + this.payload = payload; + this.payloadMono = Mono.fromSupplier(payload::retain); + this.payloadsFlux = Flux.range(0, 100000).map(__ -> payload.retain()); + this.server = + RSocketServer.create( + (setup, sendingSocket) -> + Mono.just( + new RSocket() { + + @Override + public Mono fireAndForget(Payload payload) { + payload.release(); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return payloadMono; + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return payloadsFlux; + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(serverTransport) + .block(); + + this.client = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(clientTransport) + .block(); + + try { + Field sendProcessorField = RSocketRequester.class.getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); + } catch (Throwable t) { + Field sendProcessorField = + Class.forName("io.rsocket.core.RequesterResponderSupport") + .getDeclaredField("sendProcessor"); + sendProcessorField.setAccessible(true); + + clientsQueue = (Queue) sendProcessorField.get(client); + } + } + + @Benchmark + @SuppressWarnings("unchecked") + public PayloadsPerfSubscriber fireAndForget(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.fireAndForget(payload.retain()).subscribe((CoreSubscriber) subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestResponse(Blackhole blackhole) throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestResponse(payload.retain()).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestStreamWithRequestByOneStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestStream(payload.retain()).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsMaxPerfSubscriber requestStreamWithRequestAllStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); + client.requestStream(payload.retain()).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsPerfSubscriber requestChannelWithRequestByOneStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsPerfSubscriber subscriber = new PayloadsPerfSubscriber(blackhole); + client.requestChannel(payloadsFlux).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } + + @Benchmark + public PayloadsMaxPerfSubscriber requestChannelWithRequestAllStrategy(Blackhole blackhole) + throws InterruptedException { + PayloadsMaxPerfSubscriber subscriber = new PayloadsMaxPerfSubscriber(blackhole); + client.requestChannel(payloadsFlux).subscribe(subscriber); + subscriber.await(); + + return subscriber; + } +} diff --git a/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java new file mode 100644 index 000000000..402cdb353 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/frame/FrameHeaderCodecPerf.java @@ -0,0 +1,55 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class FrameHeaderCodecPerf { + + @Benchmark + public void encode(Input input) { + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(input.allocator, FrameType.SETUP, 0); + boolean release = byteBuf.release(); + input.bh.consume(release); + } + + @Benchmark + public void decode(Input input) { + ByteBuf frame = input.frame; + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + int flags = FrameHeaderCodec.flags(frame); + input.bh.consume(streamId); + input.bh.consume(flags); + input.bh.consume(frameType); + } + + @State(Scope.Benchmark) + public static class Input { + Blackhole bh; + FrameType frameType; + ByteBufAllocator allocator; + ByteBuf frame; + + @Setup + public void setup(Blackhole bh) { + this.bh = bh; + this.frameType = FrameType.REQUEST_RESPONSE; + allocator = ByteBufAllocator.DEFAULT; + frame = FrameHeaderCodec.encode(allocator, 123, FrameType.SETUP, 0); + } + + @TearDown + public void teardown() { + frame.release(); + } + } +} diff --git a/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java b/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java new file mode 100644 index 000000000..efa22104f --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/frame/FrameTypePerf.java @@ -0,0 +1,38 @@ +package io.rsocket.frame; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class FrameTypePerf { + @Benchmark + public void lookup(Input input) { + FrameType frameType = input.frameType; + boolean b = + frameType.canHaveData() + && frameType.canHaveMetadata() + && frameType.isFragmentable() + && frameType.isRequestType() + && frameType.hasInitialRequestN(); + + input.bh.consume(b); + } + + @State(Scope.Benchmark) + public static class Input { + Blackhole bh; + FrameType frameType; + + @Setup + public void setup(Blackhole bh) { + this.bh = bh; + this.frameType = FrameType.REQUEST_RESPONSE; + } + } +} diff --git a/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java b/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java new file mode 100644 index 000000000..ead1c2fa3 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/frame/PayloadFrameCodecPerf.java @@ -0,0 +1,77 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork( + value = 1 // , jvmArgsAppend = {"-Dio.netty.leakDetection.level=advanced"} + ) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class PayloadFrameCodecPerf { + + @Benchmark + public void encode(Input input) { + ByteBuf encode = + PayloadFrameCodec.encode( + input.allocator, + 100, + false, + true, + false, + Unpooled.wrappedBuffer(input.metadata), + Unpooled.wrappedBuffer(input.data)); + boolean release = encode.release(); + input.bh.consume(release); + } + + @Benchmark + public void decode(Input input) { + ByteBuf frame = input.payload; + ByteBuf data = PayloadFrameCodec.data(frame); + ByteBuf metadata = PayloadFrameCodec.metadata(frame); + input.bh.consume(data); + input.bh.consume(metadata); + } + + @State(Scope.Benchmark) + public static class Input { + Blackhole bh; + FrameType frameType; + ByteBufAllocator allocator; + ByteBuf payload; + byte[] metadata = new byte[512]; + byte[] data = new byte[4096]; + + @Setup + public void setup(Blackhole bh) { + this.bh = bh; + this.frameType = FrameType.REQUEST_RESPONSE; + allocator = ByteBufAllocator.DEFAULT; + + // Encode a payload and then copy it a single bytebuf + payload = allocator.buffer(); + ByteBuf encode = + PayloadFrameCodec.encode( + allocator, + 100, + false, + true, + false, + Unpooled.wrappedBuffer(metadata), + Unpooled.wrappedBuffer(data)); + payload.writeBytes(encode); + encode.release(); + } + + @TearDown + public void teardown() { + payload.release(); + } + } +} diff --git a/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java new file mode 100644 index 000000000..8f429fc19 --- /dev/null +++ b/benchmarks/src/main/java/io/rsocket/metadata/WellKnownMimeTypePerf.java @@ -0,0 +1,96 @@ +package io.rsocket.metadata; + +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; + +@BenchmarkMode(Mode.Throughput) +@Fork(value = 1) +@Warmup(iterations = 10) +@Measurement(iterations = 10) +@State(Scope.Thread) +public class WellKnownMimeTypePerf { + + // this is the old values() looping implementation of fromIdentifier + private WellKnownMimeType fromIdValuesLoop(int id) { + if (id < 0 || id > 127) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (value.getIdentifier() == id) { + return value; + } + } + return WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE; + } + + // this is the core of the old values() looping implementation of fromString + private WellKnownMimeType fromStringValuesLoop(String mimeType) { + for (WellKnownMimeType value : WellKnownMimeType.values()) { + if (mimeType.equals(value.getString())) { + return value; + } + } + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE; + } + + @Benchmark + public void fromIdArrayLookup(final Blackhole bh) { + // negative lookup + bh.consume(WellKnownMimeType.fromIdentifier(-10)); + bh.consume(WellKnownMimeType.fromIdentifier(-1)); + // too large lookup + bh.consume(WellKnownMimeType.fromIdentifier(129)); + // first lookup + bh.consume(WellKnownMimeType.fromIdentifier(0)); + // middle lookup + bh.consume(WellKnownMimeType.fromIdentifier(37)); + // reserved lookup + bh.consume(WellKnownMimeType.fromIdentifier(63)); + // last lookup + bh.consume(WellKnownMimeType.fromIdentifier(127)); + } + + @Benchmark + public void fromIdValuesLoopLookup(final Blackhole bh) { + // negative lookup + bh.consume(fromIdValuesLoop(-10)); + bh.consume(fromIdValuesLoop(-1)); + // too large lookup + bh.consume(fromIdValuesLoop(129)); + // first lookup + bh.consume(fromIdValuesLoop(0)); + // middle lookup + bh.consume(fromIdValuesLoop(37)); + // reserved lookup + bh.consume(fromIdValuesLoop(63)); + // last lookup + bh.consume(fromIdValuesLoop(127)); + } + + @Benchmark + public void fromStringMapLookup(final Blackhole bh) { + // unknown lookup + bh.consume(WellKnownMimeType.fromString("foo/bar")); + // first lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(WellKnownMimeType.fromString(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + WellKnownMimeType.fromString( + WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } + + @Benchmark + public void fromStringValuesLoopLookup(final Blackhole bh) { + // unknown lookup + bh.consume(fromStringValuesLoop("foo/bar")); + // first lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.APPLICATION_AVRO.getString())); + // middle lookup + bh.consume(fromStringValuesLoop(WellKnownMimeType.VIDEO_VP8.getString())); + // last lookup + bh.consume( + fromStringValuesLoop(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString())); + } +} diff --git a/build.gradle b/build.gradle index 8fbcaa6c4..2971a7767 100644 --- a/build.gradle +++ b/build.gradle @@ -1,44 +1,290 @@ -buildscript { - repositories { - jcenter() - } +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ - dependencies { classpath 'io.reactivesocket:gradle-nebula-plugin-reactivesocket:1.0.5' } +plugins { + id 'com.github.sherter.google-java-format' version '0.9' apply false + id 'me.champeau.jmh' version '0.7.1' apply false + id 'io.spring.dependency-management' version '1.1.0' apply false + id 'io.morethan.jmhreport' version '0.9.0' apply false + id 'io.github.reyerizo.gradle.jcstress' version '0.8.15' apply false + id 'com.github.vlsi.gradle-extensions' version '1.89' apply false } -description = 'ReactiveSocket: stream oriented messaging passing with Reactive Stream semantics.' +boolean isCiServer = ["CI", "CONTINUOUS_INTEGRATION", "TRAVIS", "CIRCLECI", "bamboo_planKey", "GITHUB_ACTION"].with { + retainAll(System.getenv().keySet()) + return !isEmpty() +} -apply plugin: 'reactivesocket-project' -apply plugin: 'java' +subprojects { + apply plugin: 'io.spring.dependency-management' + apply plugin: 'com.github.sherter.google-java-format' + apply plugin: 'com.github.vlsi.gradle-extensions' -repositories { - maven { url 'https://oss.jfrog.org/libs-snapshot' } + ext['reactor-bom.version'] = '2022.0.7-SNAPSHOT' + ext['logback.version'] = '1.2.13' + ext['netty-bom.version'] = '4.1.117.Final' + ext['netty-boringssl.version'] = '2.0.69.Final' + ext['hdrhistogram.version'] = '2.1.12' + ext['mockito.version'] = '4.11.0' + ext['slf4j.version'] = '1.7.36' + ext['jmh.version'] = '1.36' + ext['junit.version'] = '5.9.3' + ext['micrometer.version'] = '1.11.12' + ext['micrometer-tracing.version'] = '1.1.13' + ext['assertj.version'] = '3.24.2' + ext['netflix.limits.version'] = '0.3.6' + ext['bouncycastle-bcpkix.version'] = '1.70' + ext['awaitility.version'] = '4.2.0' + + group = "io.rsocket" + + googleJavaFormat { + toolVersion = '1.6' + } + + ext { + if (project.hasProperty('versionSuffix')) { + project.version += project.getProperty('versionSuffix') + } + } + + configurations.all { + resolutionStrategy.cacheChangingModulesFor 60, "minutes" + } + + dependencyManagement { + imports { + mavenBom "io.projectreactor:reactor-bom:${ext['reactor-bom.version']}" + mavenBom "io.netty:netty-bom:${ext['netty-bom.version']}" + mavenBom "org.junit:junit-bom:${ext['junit.version']}" + mavenBom "io.micrometer:micrometer-bom:${ext['micrometer.version']}" + mavenBom "io.micrometer:micrometer-tracing-bom:${ext['micrometer-tracing.version']}" + } + + dependencies { + dependency "com.netflix.concurrency-limits:concurrency-limits-core:${ext['netflix.limits.version']}" + dependency "ch.qos.logback:logback-classic:${ext['logback.version']}" + dependency "io.netty:netty-tcnative-boringssl-static:${ext['netty-boringssl.version']}" + dependency "org.bouncycastle:bcpkix-jdk15on:${ext['bouncycastle-bcpkix.version']}" + dependency "org.assertj:assertj-core:${ext['assertj.version']}" + dependency "org.hdrhistogram:HdrHistogram:${ext['hdrhistogram.version']}" + dependency "org.slf4j:slf4j-api:${ext['slf4j.version']}" + dependency "org.awaitility:awaitility:${ext['awaitility.version']}" + dependencySet(group: 'org.mockito', version: ext['mockito.version']) { + entry 'mockito-junit-jupiter' + entry 'mockito-core' + } + dependencySet(group: 'org.openjdk.jmh', version: ext['jmh.version']) { + entry 'jmh-core' + entry 'jmh-generator-annprocess' + } + } + generatedPomCustomization { + enabled = false + } + } + + repositories { + mavenCentral() + + maven { + url 'https://repo.spring.io/milestone' + content { + includeGroup "io.micrometer" + includeGroup "io.projectreactor" + includeGroup "io.projectreactor.netty" + includeGroup "io.micrometer" + } + } + + maven { + url 'https://repo.spring.io/snapshot' + content { + includeGroup "io.micrometer" + includeGroup "io.projectreactor" + includeGroup "io.projectreactor.netty" + } + } + + if (version.endsWith('SNAPSHOT') || project.hasProperty('versionSuffix')) { + maven { url 'https://repo.spring.io/libs-snapshot' } + maven { url 'https://oss.jfrog.org/artifactory/oss-snapshot-local' } + mavenLocal() + } + } + + tasks.withType(GenerateModuleMetadata) { + enabled = false + } + + plugins.withType(JavaPlugin) { + + compileJava { + sourceCompatibility = 1.8 + + // TODO: Cleanup warnings so no need to exclude + options.compilerArgs << '-Xlint:all,-overloads,-rawtypes,-unchecked' + } + + javadoc { + def jdk = JavaVersion.current().majorVersion + def jdkJavadoc = "https://docs.oracle.com/javase/$jdk/docs/api/" + if (JavaVersion.current().isJava11Compatible()) { + jdkJavadoc = "https://docs.oracle.com/en/java/javase/$jdk/docs/api/" + } + options.with { + links jdkJavadoc + links 'https://projectreactor.io/docs/core/release/api/' + links 'https://netty.io/4.1/api/' + } + failOnError = false + } + + tasks.named("javadoc").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + + test { + useJUnitPlatform() + testLogging { + events "PASSED", "FAILED" + showExceptions true + showCauses true + exceptionFormat "FULL" + stackTraceFilters "ENTRY_POINT" + maxGranularity 3 + } + + //show progress by displaying test classes, avoiding test suite timeouts + TestDescriptor last + afterTest { TestDescriptor td, TestResult tr -> + if (last != td.getParent()) { + last = td.getParent() + println last + } + } + + if (isCiServer) { + def stdout = new LinkedList() + beforeTest { TestDescriptor td -> + stdout.clear() + } + onOutput { TestDescriptor td, TestOutputEvent toe -> + stdout.add(toe) + } + afterTest { TestDescriptor td, TestResult tr -> + if (tr.resultType == TestResult.ResultType.FAILURE && stdout.size() > 0) { + def stdOutput = stdout.collect { + it.getDestination().name() == "StdErr" + ? "STD_ERR: ${it.getMessage()}" + : "STD_OUT: ${it.getMessage()}" + } + .join() + println "This is the console output of the failing test below:\n$stdOutput" + } + } + + reports { + junitXml.outputPerTestCase = true + } + } + + if (JavaVersion.current().isJava9Compatible()) { + println "Java 9+: lowering MaxGCPauseMillis to 20ms in ${project.name} ${name}" + println "Java 9+: enabling leak detection [ADVANCED]" + jvmArgs = ["-XX:MaxGCPauseMillis=20", "-Dio.netty.leakDetection.level=ADVANCED", "-Dio.netty.leakDetection.samplingInterval=32"] + } + + systemProperty("java.awt.headless", "true") + systemProperty("testGroups", project.properties.get("testGroups")) + + //allow re-run of failed tests only without special test tasks failing + // because the filter is too restrictive + filter.setFailOnNoMatchingTests(false) + + //display intermediate results for special test tasks + afterSuite { desc, result -> + if (!desc.parent) { // will match the outermost suite + println('\n' + "${desc} Results: ${result.resultType} (${result.testCount} tests, ${result.successfulTestCount} successes, ${result.failedTestCount} failures, ${result.skippedTestCount} skipped)") + } + } + } + } + + plugins.withType(JavaLibraryPlugin) { + task sourcesJar(type: Jar) { + classifier 'sources' + from sourceSets.main.allJava + } + + task javadocJar(type: Jar, dependsOn: javadoc) { + classifier 'javadoc' + from javadoc.destinationDir + } + + plugins.withType(MavenPublishPlugin) { + publishing { + publications { + maven(MavenPublication) { + from components.java + artifact sourcesJar + artifact javadocJar + } + } + } + } + } } -dependencies { - compile 'org.reactivestreams:reactive-streams:1.0.0.final' - compile 'org.agrona:Agrona:0.4.13' +apply from: "${rootDir}/gradle/publications.gradle" - testCompile 'io.reactivex:rxjava:2.0.0-DP0-20151003.214425-143' - testCompile 'junit:junit:4.12' - testCompile 'org.mockito:mockito-core:1.10.19' +buildScan { + termsOfServiceUrl = 'https://gradle.com/terms-of-service' + termsOfServiceAgree = 'yes' } -// support for snapshot/final releases via versioned branch names like 1.x -nebulaRelease { - addReleaseBranchPattern(/\d+\.\d+\.\d+/) - addReleaseBranchPattern('HEAD') +description = 'RSocket: Stream Oriented Messaging Passing with Reactive Stream Semantics.' + +repositories { + mavenCentral() + + maven { url 'https://repo.spring.io/snapshot' } + mavenLocal() } -if (project.hasProperty('release.useLastTag')) { - tasks.prepare.enabled = false +configurations { + adoc } -test { - testLogging.showStandardStreams = true +dependencies { + adoc "io.micrometer:micrometer-docs-generator-spans:1.0.0-SNAPSHOT" + adoc "io.micrometer:micrometer-docs-generator-metrics:1.0.0-SNAPSHOT" } -compileJava { - sourceCompatibility = 1.8 - targetCompatibility = 1.8 -} \ No newline at end of file +task generateObservabilityDocs(dependsOn: ["generateObservabilityMetricsDocs", "generateObservabilitySpansDocs"]) { +} + +task generateObservabilityMetricsDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.metrics.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() +} + +task generateObservabilitySpansDocs(type: JavaExec) { + mainClass = "io.micrometer.docs.spans.DocsFromSources" + classpath configurations.adoc + args project.rootDir.getAbsolutePath(), ".*", project.rootProject.buildDir.getAbsolutePath() +} diff --git a/gradle.properties b/gradle.properties index ef6032984..d138852c5 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1 +1,15 @@ -release.scope=patch +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +version=1.2.0 +perfBaselineVersion=1.1.4 diff --git a/gradle/buildViaTravis.sh b/gradle/buildViaTravis.sh deleted file mode 100755 index d98e5eb60..000000000 --- a/gradle/buildViaTravis.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -# This script will build the project. - -if [ "$TRAVIS_PULL_REQUEST" != "false" ]; then - echo -e "Build Pull Request #$TRAVIS_PULL_REQUEST => Branch [$TRAVIS_BRANCH]" - ./gradlew -Prelease.useLastTag=true build -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" == "" ]; then - echo -e 'Build Branch with Snapshot => Branch ['$TRAVIS_BRANCH']' - ./gradlew -Prelease.travisci=true -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" build snapshot --stacktrace -elif [ "$TRAVIS_PULL_REQUEST" == "false" ] && [ "$TRAVIS_TAG" != "" ]; then - echo -e 'Build Branch for Release => Branch ['$TRAVIS_BRANCH'] Tag ['$TRAVIS_TAG']' - ./gradlew -Prelease.travisci=true -Prelease.useLastTag=true -PbintrayUser="${bintrayUser}" -PbintrayKey="${bintrayKey}" -PsonatypeUsername="${sonatypeUsername}" -PsonatypePassword="${sonatypePassword}" final --stacktrace -else - echo -e 'WARN: Should not be here => Branch ['$TRAVIS_BRANCH'] Tag ['$TRAVIS_TAG'] Pull Request ['$TRAVIS_PULL_REQUEST']' - ./gradlew -Prelease.useLastTag=true build -fi diff --git a/gradle/github-pkg.gradle b/gradle/github-pkg.gradle new file mode 100644 index 000000000..f53413766 --- /dev/null +++ b/gradle/github-pkg.gradle @@ -0,0 +1,21 @@ +subprojects { + + plugins.withType(MavenPublishPlugin) { + publishing { + repositories { + maven { + name = "GitHubPackages" + url = uri("https://maven.pkg.github.com/rsocket/rsocket-java") + credentials { + username = project.findProperty("gpr.user") ?: System.getenv("GITHUB_ACTOR") + password = project.findProperty("gpr.key") ?: System.getenv("GITHUB_TOKEN") + } + } + } + } + + tasks.named("publish").configure { + onlyIf { System.getenv('SKIP_RELEASE') != "true" } + } + } +} \ No newline at end of file diff --git a/gradle/publications.gradle b/gradle/publications.gradle new file mode 100644 index 000000000..9e8dd6d88 --- /dev/null +++ b/gradle/publications.gradle @@ -0,0 +1,53 @@ +apply from: "${rootDir}/gradle/github-pkg.gradle" +apply from: "${rootDir}/gradle/sonotype.gradle" + +subprojects { + plugins.withType(MavenPublishPlugin) { + publishing { + publications { + maven(MavenPublication) { + pom { + name = project.name + afterEvaluate { + description = project.description + } + groupId = 'io.rsocket' + url = 'http://rsocket.io' + licenses { + license { + name = "The Apache Software License, Version 2.0" + url = "https://www.apache.org/licenses/LICENSE-2.0.txt" + distribution = "repo" + } + } + developers { + developer { + id = 'OlegDokuka' + name = 'Oleh Dokuka' + email = 'oleh.dokuka@icloud.com' + } + developer { + id = 'rstoyanchev' + name = 'Rossen Stoyanchev' + email = 'rstoyanchev@vmware.com' + } + } + scm { + connection = 'scm:git:https://github.com/rsocket/rsocket-java.git' + developerConnection = 'scm:git:https://github.com/rsocket/rsocket-java.git' + url = 'https://github.com/rsocket/rsocket-java' + } + versionMapping { + usage('java-api') { + fromResolutionResult() + } + usage('java-runtime') { + fromResolutionResult() + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/gradle/sonotype.gradle b/gradle/sonotype.gradle new file mode 100644 index 000000000..f339079b0 --- /dev/null +++ b/gradle/sonotype.gradle @@ -0,0 +1,36 @@ +subprojects { + if (project.hasProperty('sonatypeUsername') && project.hasProperty('sonatypePassword')) { + plugins.withType(MavenPublishPlugin) { + plugins.withType(SigningPlugin) { + + signing { + //requiring signature if there is a publish task that is not to MavenLocal + required { gradle.taskGraph.allTasks.any { it.name.toLowerCase().contains("publish") && !it.name.contains("MavenLocal") } } + def signingKey = project.findProperty("signingKey") + def signingPassword = project.findProperty("signingPassword") + + useInMemoryPgpKeys(signingKey, signingPassword) + + afterEvaluate { + sign publishing.publications.maven + } + } + + publishing { + repositories { + maven { + name = "sonatype" + url = project.version.contains("-SNAPSHOT") + ? "https://oss.sonatype.org/content/repositories/snapshots/" + : "https://oss.sonatype.org/service/local/staging/deploy/maven2" + credentials { + username project.findProperty("sonatypeUsername") + password project.findProperty("sonatypePassword") + } + } + } + } + } + } + } +} \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 5ccda13e9..249e5832f 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index f1df5b75c..774fae876 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,5 @@ -#Mon Mar 07 16:10:12 PST 2016 distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-7.6.1-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-2.11-all.zip diff --git a/gradlew b/gradlew index 9d82f7891..a69d9cb6c 100755 --- a/gradlew +++ b/gradlew @@ -1,74 +1,129 @@ -#!/usr/bin/env bash +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## -# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -DEFAULT_JVM_OPTS="" +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +APP_BASE_NAME=${0##*/} + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum -warn ( ) { +warn () { echo "$*" -} +} >&2 -die ( ) { +die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). cygwin=false msys=false darwin=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MINGW* ) - msys=true - ;; +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac -# Attempt to set APP_HOME -# Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi -done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null - CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + # Determine the Java command to use to start the JVM. if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -77,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -85,76 +140,101 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac fi -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) -# For Cygwin, switch paths to Windows format before running java -if $cygwin ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) fi - i=$((i+1)) + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg done - case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; - esac fi -# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules -function splitJvmOpts() { - JVM_OPTS=("$@") -} -eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS -JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi -exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index 5f192121e..53a6b238d 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,90 +1,91 @@ -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS= - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto init - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args -if "%@eval[2+2]" == "4" goto 4NT_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* -goto execute - -:4NT_args -@rem Get arguments from the 4NT Shell from JP Software -set CMD_LINE_ARGS=%$ - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/rsocket-bom/build.gradle b/rsocket-bom/build.gradle new file mode 100755 index 000000000..a75ab3bc8 --- /dev/null +++ b/rsocket-bom/build.gradle @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'java-platform' + id 'maven-publish' + id 'signing' +} + +description = 'RSocket Java Bill of materials.' + +def excluded = ["rsocket-examples", "benchmarks"] + +dependencies { + constraints { + parent.subprojects.findAll { it.name != project.name && !excluded.contains(it.name) } .sort { "$it.name" }.each { + api it + } + } +} + +publishing { + publications { + maven(MavenPublication) { + from components.javaPlatform + } + } +} \ No newline at end of file diff --git a/rsocket-core/build.gradle b/rsocket-core/build.gradle new file mode 100644 index 000000000..da5b69b14 --- /dev/null +++ b/rsocket-core/build.gradle @@ -0,0 +1,60 @@ +/* + * Copyright 2015-2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' + id 'io.morethan.jmhreport' + id 'me.champeau.jmh' + id 'io.github.reyerizo.gradle.jcstress' +} + +dependencies { + api 'io.netty:netty-buffer' + api 'io.projectreactor:reactor-core' + + implementation 'org.slf4j:slf4j-api' + + testImplementation (project(":rsocket-transport-local")) + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' + testImplementation 'org.mockito:mockito-junit-jupiter' + testImplementation 'org.awaitility:awaitility' + + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' + + jcstressImplementation(project(":rsocket-test")) + jcstressImplementation 'org.slf4j:slf4j-api' + jcstressImplementation "ch.qos.logback:logback-classic" + jcstressImplementation 'io.projectreactor:reactor-test' +} + +jcstress { + mode = 'sanity' //sanity, quick, default, tough + jcstressDependency = "org.openjdk.jcstress:jcstress-core:0.16" +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.core") + } +} + +description = "Core functionality for the RSocket library" diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..e91be2451 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/FireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,115 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class FireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport(testDuplexConnection, StreamIdSupplier.clientSupplier()); + + final FireAndForgetRequesterMono source = source(); + + abstract FireAndForgetRequesterMono source(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + FireAndForgetRequesterMono source() { + return new FireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java new file mode 100644 index 000000000..ef79d344d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/ReconnectMonoStressTest.java @@ -0,0 +1,604 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.ResolvingOperator.EMPTY_SUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.EMPTY_UNSUBSCRIBED; +import static io.rsocket.core.ResolvingOperator.READY; +import static io.rsocket.core.ResolvingOperator.TERMINATED; +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.BiConsumer; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.IIIIIII_Result; +import org.openjdk.jcstress.infra.results.IIIIII_Result; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; + +public abstract class ReconnectMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscription stressSubscription = new StressSubscription<>(); + + final Mono source = source(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + volatile int onValueExpire; + + static final AtomicIntegerFieldUpdater ON_VALUE_EXPIRE = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueExpire"); + + volatile int onValueReceived; + + static final AtomicIntegerFieldUpdater ON_VALUE_RECEIVED = + AtomicIntegerFieldUpdater.newUpdater(BaseStressTest.class, "onValueReceived"); + final ReconnectMono reconnectMono = + new ReconnectMono<>( + source, + (__) -> ON_VALUE_EXPIRE.incrementAndGet(BaseStressTest.this), + (__, ___) -> ON_VALUE_RECEIVED.incrementAndGet(BaseStressTest.this)); + + abstract Mono source(); + + int state() { + final BiConsumer[] subscribers = reconnectMono.resolvingInner.subscribers; + if (subscribers == EMPTY_UNSUBSCRIBED) { + return 0; + } else if (subscribers == EMPTY_SUBSCRIBED) { + return 1; + } else if (subscribers == READY) { + return 2; + } else if (subscribers == TERMINATED) { + return 3; + } else { + return 4; + } + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before value is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onComplete but before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after value is delivered") + @State + public static class ExpireValueOnRacingDisposeAndNext extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed before error is delivered") + @Outcome( + id = {"0, 0, 0, 1, 1, 0, 3"}, + expect = ACCEPTABLE, + desc = "Disposed after onError") + @State + public static class ExpireValueOnRacingDisposeAndError extends BaseStressTest { + + { + Hooks.onErrorDropped(t -> {}); + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onError(new RuntimeException("boom")); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + Hooks.resetOnErrorDropped(); + + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 0, 1, 2"}, + expect = ACCEPTABLE, + desc = "Invalidate happens before value is delivered") + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Invalidate happens after value is delivered") + @State + public static class ExpireValueOnRacingInvalidateAndNextComplete extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNext() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 0"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndInvalidate extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate1() { + reconnectMono.invalidate(); + } + + @Actor + void invalidate2() { + reconnectMono.invalidate(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"0, 1, 1, 0, 1, 1, 3"}, + expect = ACCEPTABLE) + @State + public static class ExpireValueOnceOnRacingInvalidateAndDispose extends BaseStressTest { + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void dispose() { + reconnectMono.dispose(); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.cancelled ? 1 : 0; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = stressSubscriber.onCompleteCalls; + r.r4 = stressSubscriber.onErrorCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 2, 2, 0, 1"}, + expect = ACCEPTABLE) + @State + public static class DeliversValueToAllSubscribersUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void sendNextAndComplete() { + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.requestsCount; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 1, 1, 1, 1, 4"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed after invalidate") + @Outcome( + id = {"1, 0, 2, 2, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Second Subscriber subscribed before invalidate and received value") + @State + public static class InvalidateAndSubscribeUnderRace extends BaseStressTest { + + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + { + reconnectMono.subscribe(stressSubscriber); + stressSubscription.actual.onNext("value"); + stressSubscription.actual.onComplete(); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.onNextCalls + stressSubscriber2.onNextCalls; + r.r4 = stressSubscriber.onCompleteCalls + stressSubscriber2.onCompleteCalls; + r.r5 = onValueExpire; + r.r6 = onValueReceived; + r.r7 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"2, 0, 2, 1, 2, 2"}, + expect = ACCEPTABLE, + desc = "Subscribed again after invalidate") + @Outcome( + id = {"1, 0, 1, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "Subscribed before invalidate") + @State + public static class InvalidateAndBlockUnderRace extends BaseStressTest { + + String receivedValue; + + { + reconnectMono.subscribe(stressSubscriber); + } + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void invalidate() { + reconnectMono.invalidate(); + } + + @Actor + void secondSubscribe() { + receivedValue = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals("value1") ? 1 : receivedValue.equals("value2") ? 2 : -1; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRace extends BaseStressTest { + + StressSubscriber stressSubscriber2 = new StressSubscriber<>(); + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void subscribe1() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void subscribe2() { + reconnectMono.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = stressSubscriber.values.get(0).equals(stressSubscriber2.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class SubscribeBlockConnectRace extends BaseStressTest { + + String receivedValue; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block() { + receivedValue = reconnectMono.block(); + } + + @Actor + void subscribe() { + reconnectMono.subscribe(stressSubscriber); + } + + @Actor + void connect() { + reconnectMono.resolvingInner.connect(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue.equals(stressSubscriber.values.get(0)) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } + + @JCStressTest + @Outcome( + id = {"1, 0, 1, 0, 1, 2"}, + expect = ACCEPTABLE) + @State + public static class TwoBlocksRace extends BaseStressTest { + + String receivedValue1; + String receivedValue2; + + @Override + Mono source() { + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + stressSubscription.subscribe(actual); + actual.onNext("value" + stressSubscription.subscribes); + actual.onComplete(); + } + }; + } + + @Actor + void block1() { + receivedValue1 = reconnectMono.block(); + } + + @Actor + void block2() { + receivedValue2 = reconnectMono.block(); + } + + @Arbiter + public void arbiter(IIIIII_Result r) { + r.r1 = stressSubscription.subscribes; + r.r2 = stressSubscription.cancelled ? 1 : 0; + r.r3 = receivedValue1.equals(receivedValue2) ? 1 : 2; + r.r4 = onValueExpire; + r.r5 = onValueReceived; + r.r6 = state(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java new file mode 100644 index 000000000..1dde77b34 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/RequestResponseRequesterMonoStressTest.java @@ -0,0 +1,650 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import java.util.stream.IntStream; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLLL_Result; +import org.openjdk.jcstress.infra.results.LLLL_Result; + +public abstract class RequestResponseRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(initialRequest()); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker; + + final TestRequesterResponderSupport requesterResponderSupport; + + final RequestResponseRequesterMono source; + + BaseStressTest(RequesterLeaseTracker requesterLeaseTracker) { + this.requesterLeaseTracker = requesterLeaseTracker; + this.requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + this.source = source(); + } + + abstract RequestResponseRequesterMono source(); + + abstract long initialRequest(); + } + + abstract static class BaseStressTestWithLease extends BaseStressTest { + + BaseStressTestWithLease(int maximumAllowedAwaitingPermitHandlersNumber) { + super(new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber)); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTestWithLease { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + public TwoSubscribesRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return Long.MAX_VALUE; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + this.source.handleNext(nextFrame, false, true); + nextFrame.release(); + + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + this.stressSubscriber1.values.forEach(Payload::release); + + r.r5 = this.source.payload.refCnt() + nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelRaceStressTest extends BaseStressTestWithLease { + + public SubscribeAndRequestAndCancelRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + public SubscribeAndRequestAndCancelWithDeferredLeaseRaceStressTest() { + super(1); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "NoLeaseError delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest + extends BaseStressTestWithLease { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + SubscribeAndRequestAndCancelWithDeferredLease2RaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + r.r6 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 0, 2, 0, " + (0x04 + 2 * 0x09)}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndRequestAndCancel extends BaseStressTest { + + SubscribeAndRequestAndCancel() { + super(null); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + long initialRequest() { + return 0; + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Actor + public void request() { + this.stressSubscriber.request(1); + this.stressSubscriber.request(Long.MAX_VALUE); + this.stressSubscriber.request(1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.source.payload.refCnt(); + r.r5 = + IntStream.range(0, this.outboundSubscriber.values.size()) + .map( + i -> + FrameHeaderCodec.frameType(this.outboundSubscriber.values.get(i)) + .getEncodedType() + * (i + 1)) + .sum(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundNextRaceStressTest extends BaseStressTestWithLease { + + final ByteBuf nextFrame = + PayloadFrameCodec.encode( + this.testDuplexConnection.alloc(), + 1, + false, + true, + true, + null, + ByteBufUtil.writeUtf8(this.testDuplexConnection.alloc(), "response-data")); + + CancelWithInboundNextRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundNext() { + this.source.handleNext(this.nextFrame, false, true); + this.nextFrame.release(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt() + this.nextFrame.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first or in between") + @State + public static class CancelWithInboundCompleteRaceStressTest extends BaseStressTestWithLease { + + CancelWithInboundCompleteRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundComplete() { + this.source.handleComplete(); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 2, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 3, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first. inbound error dropped") + @State + public static class CancelWithInboundErrorRaceStressTest extends BaseStressTestWithLease { + + static final RuntimeException ERROR = new RuntimeException("Test"); + + CancelWithInboundErrorRaceStressTest() { + super(0); + } + + @Override + RequestResponseRequesterMono source() { + return new RequestResponseRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + + this.source.subscribe(this.stressSubscriber); + } + + @Override + long initialRequest() { + return 1; + } + + @Actor + public void inboundError() { + this.source.handleError(ERROR); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.stressSubscriber.onNextCalls; + + this.outboundSubscriber.values.forEach(ByteBuf::release); + this.stressSubscriber.values.forEach(Payload::release); + + r.r4 = this.source.payload.refCnt(); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java new file mode 100644 index 000000000..5de7eb4b9 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/SlowFireAndForgetRequesterMonoStressTest.java @@ -0,0 +1,288 @@ +package io.rsocket.core; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.test.TestDuplexConnection; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLLL_Result; + +public abstract class SlowFireAndForgetRequesterMonoStressTest { + + abstract static class BaseStressTest { + + final StressSubscriber outboundSubscriber = new StressSubscriber<>(); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(); + + final TestDuplexConnection testDuplexConnection = + new TestDuplexConnection(this.outboundSubscriber, false); + + final RequesterLeaseTracker requesterLeaseTracker = + new RequesterLeaseTracker("test", maximumAllowedAwaitingPermitHandlersNumber()); + + final TestRequesterResponderSupport requesterResponderSupport = + new TestRequesterResponderSupport( + testDuplexConnection, StreamIdSupplier.clientSupplier(), requesterLeaseTracker); + + final SlowFireAndForgetRequesterMono source = source(); + + abstract SlowFireAndForgetRequesterMono source(); + + abstract int maximumAllowedAwaitingPermitHandlersNumber(); + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 3, 1, 0, 0"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends BaseStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe1() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void subscribe2() { + this.source.subscribe(this.stressSubscriber1); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber1.onCompleteCalls + + this.stressSubscriber1.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @State + public static class SubscribeAndCancelRaceStressTest extends BaseStressTest { + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + // init + { + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @State + public static class SubscribeAndCancelWithDeferredLeaseRaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 1; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = this.stressSubscriber.onCompleteCalls + this.stressSubscriber.onErrorCalls * 2; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } + + @JCStressTest + @Outcome( + id = {"-9223372036854775808, 1, 1, 0, 0"}, + expect = ACCEPTABLE, + desc = "frame delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 2, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "no lease error delivered before cancellation") + @Outcome( + id = {"-9223372036854775808, 0, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened first") + @Outcome( + id = {"-9223372036854775808, 0, 0, 0, 0"}, + expect = ACCEPTABLE, + desc = "cancellation happened in between") + @Outcome( + id = {"-9223372036854775808, 3, 0, 1, 0"}, + expect = ACCEPTABLE, + desc = + "cancellation happened after lease permit requested but before it was actually decided and in the case when no lease are available. Error is dropped") + @State + public static class SubscribeAndCancelWithDeferredLease2RaceStressTest extends BaseStressTest { + + final ByteBuf leaseFrame = + LeaseFrameCodec.encode(this.testDuplexConnection.alloc(), 1000, 1, null); + + @Override + SlowFireAndForgetRequesterMono source() { + return new SlowFireAndForgetRequesterMono( + UnpooledByteBufPayload.create( + "test-data", "test-metadata", this.requesterResponderSupport.getAllocator()), + this.requesterResponderSupport); + } + + @Override + int maximumAllowedAwaitingPermitHandlersNumber() { + return 0; + } + + @Actor + public void issueLease() { + final ByteBuf leaseFrame = this.leaseFrame; + this.requesterLeaseTracker.handleLeaseFrame(leaseFrame); + leaseFrame.release(); + } + + @Actor + public void subscribe() { + this.source.subscribe(this.stressSubscriber); + } + + @Actor + public void cancel() { + this.stressSubscriber.cancel(); + } + + @Arbiter + public void arbiter(LLLLL_Result r) { + r.r1 = this.source.state; + r.r2 = + this.stressSubscriber.onCompleteCalls + + this.stressSubscriber.onErrorCalls * 2 + + this.stressSubscriber.droppedErrors.size() * 3; + r.r3 = this.outboundSubscriber.onNextCalls; + r.r4 = this.requesterLeaseTracker.availableRequests; + r.r5 = this.source.payload.refCnt(); + + this.outboundSubscriber.values.forEach(ByteBuf::release); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java new file mode 100644 index 000000000..883077f77 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscriber.java @@ -0,0 +1,472 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static reactor.core.publisher.Operators.addCap; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +public class StressSubscriber implements CoreSubscriber { + + enum Operation { + ON_NEXT, + ON_ERROR, + ON_COMPLETE, + ON_SUBSCRIBE + } + + final Context context; + final int requestedFusionMode; + + int fusionMode; + Subscription subscription; + + public Throwable error; + public boolean done; + + public List droppedErrors = new CopyOnWriteArrayList<>(); + + public List values = new ArrayList<>(); + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscriber.class, "requested"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "wip"); + + public volatile Operation guard; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater GUARD = + AtomicReferenceFieldUpdater.newUpdater(StressSubscriber.class, Operation.class, "guard"); + + public volatile boolean concurrentOnNext; + + public volatile boolean concurrentOnError; + + public volatile boolean concurrentOnComplete; + + public volatile boolean concurrentOnSubscribe; + + public volatile int onNextCalls; + + public BlockingQueue q = new LinkedBlockingDeque<>(); + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextCalls"); + + public volatile int onNextDiscarded; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_NEXT_DISCARDED = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onNextDiscarded"); + + public volatile int onErrorCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_ERROR_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onErrorCalls"); + + public volatile int onCompleteCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_COMPLETE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onCompleteCalls"); + + public volatile int onSubscribeCalls; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater ON_SUBSCRIBE_CALLS = + AtomicIntegerFieldUpdater.newUpdater(StressSubscriber.class, "onSubscribeCalls"); + + /** Build a {@link StressSubscriber} that makes an unbounded request upon subscription. */ + public StressSubscriber() { + this(Long.MAX_VALUE, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest) { + this(initRequest, Fuseable.NONE); + } + + /** + * Build a {@link StressSubscriber} that requests the provided amount in {@link + * #onSubscribe(Subscription)}. Use {@code 0} to avoid any initial request upon subscription. + * + * @param initRequest the requested amount upon subscription, or zero to disable initial request + */ + public StressSubscriber(long initRequest, int requestedFusionMode) { + this.requestedFusionMode = requestedFusionMode; + this.context = + Operators.enableOnDiscard( + Context.of( + "reactor.onErrorDropped.local", + (Consumer) throwable -> droppedErrors.add(throwable)), + (__) -> ON_NEXT_DISCARDED.incrementAndGet(this)); + REQUESTED.lazySet(this, initRequest | Long.MIN_VALUE); + } + + @Override + public Context currentContext() { + return this.context; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (!GUARD.compareAndSet(this, null, Operation.ON_SUBSCRIBE)) { + concurrentOnSubscribe = true; + subscription.cancel(); + } else { + final boolean isValid = Operators.validate(this.subscription, subscription); + if (isValid) { + this.subscription = subscription; + } + GUARD.compareAndSet(this, Operation.ON_SUBSCRIBE, null); + + if (this.requestedFusionMode > 0 && subscription instanceof Fuseable.QueueSubscription) { + final int m = + ((Fuseable.QueueSubscription) subscription).requestFusion(this.requestedFusionMode); + final long requested = this.requested; + this.fusionMode = m; + if (m != Fuseable.NONE) { + if (requested == Long.MAX_VALUE) { + subscription.cancel(); + } + drain(); + return; + } + } + + if (isValid) { + long delivered = 0; + for (; ; ) { + long s = requested; + if (s == Long.MAX_VALUE) { + subscription.cancel(); + break; + } + + long r = s & Long.MAX_VALUE; + long toRequest = r - delivered; + if (toRequest > 0) { + subscription.request(toRequest); + delivered = r; + } + + if (REQUESTED.compareAndSet(this, s, 0)) { + break; + } + } + } + } + ON_SUBSCRIBE_CALLS.incrementAndGet(this); + } + + @Override + public void onNext(T value) { + if (fusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + if (!GUARD.compareAndSet(this, null, Operation.ON_NEXT)) { + concurrentOnNext = true; + } else { + values.add(value); + GUARD.compareAndSet(this, Operation.ON_NEXT, null); + } + ON_NEXT_CALLS.incrementAndGet(this); + } + + @Override + public void onError(Throwable throwable) { + if (!GUARD.compareAndSet(this, null, Operation.ON_ERROR)) { + concurrentOnError = true; + } else { + GUARD.compareAndSet(this, Operation.ON_ERROR, null); + } + + if (done) { + throw new IllegalStateException("Already done"); + } + + error = throwable; + done = true; + q.offer(throwable); + ON_ERROR_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + @Override + public void onComplete() { + if (!GUARD.compareAndSet(this, null, Operation.ON_COMPLETE)) { + concurrentOnComplete = true; + } else { + GUARD.compareAndSet(this, Operation.ON_COMPLETE, null); + } + if (done) { + throw new IllegalStateException("Already done"); + } + + done = true; + ON_COMPLETE_CALLS.incrementAndGet(this); + + if (fusionMode == Fuseable.ASYNC) { + drain(); + } + } + + public void request(long n) { + if (Operators.validate(n)) { + for (; ; ) { + final long s = this.requested; + if (s == 0) { + this.subscription.request(n); + return; + } + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = addCap(r, n); + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + } + + public void cancel() { + for (; ; ) { + long s = this.requested; + if (s == 0) { + this.subscription.cancel(); + return; + } + + if (REQUESTED.compareAndSet(this, s, Long.MAX_VALUE)) { + if (this.fusionMode != Fuseable.NONE) { + drain(); + } + return; + } + } + } + + @SuppressWarnings("unchecked") + private void drain() { + final int previousState = markWorkAdded(); + if (isFinalized(previousState)) { + ((Queue) this.subscription).clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + final Subscription s = this.subscription; + final Queue q = (Queue) s; + + int expectedState = previousState + 1; + for (; ; ) { + long r = this.requested & Long.MAX_VALUE; + long e = 0L; + + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + boolean done = this.done; + + final T t = q.poll(); + final boolean empty = t == null; + + if (checkTerminated(done, empty)) { + if (!empty) { + values.add(t); + } + return; + } + + if (empty) { + break; + } + + values.add(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + boolean done = this.done; + boolean empty = q.isEmpty(); + + if (checkTerminated(done, empty)) { + return; + } + } + + if (e != 0) { + ON_NEXT_CALLS.addAndGet(this, (int) e); + if (r != Long.MAX_VALUE) { + produce(e); + } + } + + expectedState = markWorkDone(expectedState); + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + boolean checkTerminated(boolean done, boolean empty) { + final long state = this.requested; + if (state == Long.MAX_VALUE) { + this.subscription.cancel(); + clearAndFinalize(); + return true; + } + + if (done && empty) { + clearAndFinalize(); + return true; + } + + return false; + } + + final void produce(long produced) { + for (; ; ) { + final long s = this.requested; + + if ((s & Long.MIN_VALUE) != Long.MIN_VALUE) { + return; + } + + final long r = s & Long.MAX_VALUE; + if (r == Long.MAX_VALUE) { + return; + } + + final long u = r - produced; + if (REQUESTED.compareAndSet(this, s, u | Long.MIN_VALUE)) { + return; + } + } + } + + @SuppressWarnings("unchecked") + final void clearAndFinalize() { + final Queue q = (Queue) this.subscription; + for (; ; ) { + final int state = this.wip; + + q.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + final int markWorkAdded() { + for (; ; ) { + final int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + int nextState = state + 1; + if ((nextState & Integer.MAX_VALUE) == 0) { + return state; + } + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + final int markWorkDone(int expectedState) { + for (; ; ) { + final int state = this.wip; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + if (WIP.compareAndSet(this, state, 0)) { + return 0; + } + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return (state & Integer.MAX_VALUE) > 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java new file mode 100644 index 000000000..3b51b8ef6 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/StressSubscription.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2020-Present Pivotal Software Inc, All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Operators; + +public class StressSubscription implements Subscription { + + CoreSubscriber actual; + + public volatile int subscribes; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater SUBSCRIBES = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "subscribes"); + + public volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(StressSubscription.class, "requested"); + + public volatile int requestsCount; + + @SuppressWarnings("rawtype s") + static final AtomicIntegerFieldUpdater REQUESTS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(StressSubscription.class, "requestsCount"); + + public volatile boolean cancelled; + + void subscribe(CoreSubscriber actual) { + this.actual = actual; + actual.onSubscribe(this); + SUBSCRIBES.getAndIncrement(this); + } + + @Override + public void request(long n) { + REQUESTS_COUNT.incrementAndGet(this); + Operators.addCap(REQUESTED, this, n); + } + + @Override + public void cancel() { + cancelled = true; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..420da66ba --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,39 @@ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import reactor.util.annotation.Nullable; + +public class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { + + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + public TestRequesterResponderSupport( + DuplexConnection connection, StreamIdSupplier streamIdSupplier) { + this(connection, streamIdSupplier, null); + } + + public TestRequesterResponderSupport( + DuplexConnection connection, + StreamIdSupplier streamIdSupplier, + @Nullable RequesterLeaseTracker requesterLeaseTracker) { + super( + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + PayloadDecoder.ZERO_COPY, + connection, + streamIdSupplier, + __ -> null); + this.requesterLeaseTracker = requesterLeaseTracker; + } + + @Override + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java new file mode 100644 index 000000000..22c478979 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/core/UnpooledByteBufPayload.java @@ -0,0 +1,155 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class UnpooledByteBufPayload extends AbstractReferenceCounted implements Payload { + + private final ByteBuf data; + private final ByteBuf metadata; + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data) { + return create(data, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload(ByteBufUtil.writeUtf8(allocator, data), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata) { + return create(data, metadata, ByteBufAllocator.DEFAULT); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata, ByteBufAllocator allocator) { + return new UnpooledByteBufPayload( + ByteBufUtil.writeUtf8(allocator, data), + metadata == null ? null : ByteBufUtil.writeUtf8(allocator, metadata)); + } + + public UnpooledByteBufPayload(ByteBuf data, @Nullable ByteBuf metadata) { + this.data = data; + this.metadata = metadata; + } + + @Override + public boolean hasMetadata() { + ensureAccessible(); + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + ensureAccessible(); + return data; + } + + @Override + public ByteBuf metadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + ensureAccessible(); + return data.slice(); + } + + @Override + public UnpooledByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public UnpooledByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public UnpooledByteBufPayload touch() { + ensureAccessible(); + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public UnpooledByteBufPayload touch(Object hint) { + ensureAccessible(); + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + if (metadata != null) { + metadata.release(); + } + } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link UnpooledByteBufPayload#ensureAccessible()} to try to guard against + * using the buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java new file mode 100644 index 000000000..a2d9fcf4d --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/internal/UnboundedProcessorStressTest.java @@ -0,0 +1,1733 @@ +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.core.StressSubscriber; +import io.rsocket.utils.FastLogger; +import java.util.Arrays; +import java.util.ConcurrentModificationException; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.Expect; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LLLL_Result; +import org.openjdk.jcstress.infra.results.LLL_Result; +import org.openjdk.jcstress.infra.results.L_Result; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.util.Logger; + +public abstract class UnboundedProcessorStressTest { + + static { + Hooks.onErrorDropped(t -> {}); + } + + final Logger logger = new FastLogger(getClass().getName()); + + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(logger); + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class SmokeFusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.onCompleteCalls > 0 && stressSubscriber.onErrorCalls > 0) { + throw new RuntimeException("boom"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke24StressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @State + public static class Smoke2FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", + "1, 1, 0", + "2, 1, 0", + "3, 1, 0", + "4, 1, 0", + + // dropped error scenarios + "0, 4, 0", + "1, 4, 0", + "2, 4, 0", + "3, 4, 0", + "4, 4, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete() before dispose() || onError()") + @Outcome( + id = { + "0, 2, 0", "1, 2, 0", "2, 2, 0", "3, 2, 0", "4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onError() before dispose() || onComplete()") + @Outcome( + id = { + "0, 2, 0", + "1, 2, 0", + "2, 2, 0", + "3, 2, 0", + "4, 2, 0", + // dropped error + "0, 5, 0", + "1, 5, 0", + "2, 5, 0", + "3, 5, 0", + "4, 5, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before onError() || onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with dispose() || onError() || onComplete()") + @State + public static class Smoke21FusedStressTest extends UnboundedProcessorStressTest { + + static final RuntimeException testException = new RuntimeException("test"); + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Actor + public void error() { + unboundedProcessor.onError(testException); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @Outcome( + id = { + "0, 0, 0", + "1, 0, 0", + "2, 0, 0", + "3, 0, 0", + "4, 0, 0", + // interleave with error or complete happened first but dispose suppressed them + "0, 3, 0", + "1, 3, 0", + "2, 3, 0", + "3, 3, 0", + "4, 3, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "cancel() before or interleave with onComplete()") + @State + public static class Smoke30StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke31StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void subscribeAndRequest() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + if (stressSubscriber.concurrentOnNext || stressSubscriber.concurrentOnComplete) { + throw new ConcurrentModificationException("boo"); + } + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0", "1, 1, 0", "2, 1, 0", "3, 1, 0", "4, 1, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke32StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.onComplete(); + } + + @Arbiter + public void arbiter(LLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "0, 1, 0, 5", + "1, 1, 0, 5", + "2, 1, 0, 5", + "3, 1, 0, 5", + "4, 1, 0, 5", + "5, 1, 0, 5", + }, + expect = Expect.ACCEPTABLE, + desc = "onComplete()") + @State + public static class Smoke33StressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = + new StressSubscriber<>(Long.MAX_VALUE, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + final ByteBuf byteBuf5 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(5); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void next1() { + unboundedProcessor.tryEmitNormal(byteBuf1); + unboundedProcessor.tryEmitPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.tryEmitPrioritized(byteBuf3); + unboundedProcessor.tryEmitNormal(byteBuf4); + } + + @Actor + public void complete() { + unboundedProcessor.tryEmitFinal(byteBuf5); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = stressSubscriber.onNextCalls; + r.r2 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + r.r4 = stressSubscriber.values.get(stressSubscriber.values.size() - 1).readByte(); + stressSubscriber.values.forEach(ByteBuf::release); + + r.r3 = + byteBuf1.refCnt() + + byteBuf2.refCnt() + + byteBuf3.refCnt() + + byteBuf4.refCnt() + + byteBuf5.refCnt(); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class RequestVsCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + { + unboundedProcessor.subscribe(stressSubscriber); + } + + @Actor + public void request() { + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void cancel() { + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-2954361355555045376, 4, 2, 0", + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 4, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 4, 0, 0", + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 2, 0", + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 3, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 3, 0, 0", + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 2, 0", + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 2, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 2, 0, 0", + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 2, 0", + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 1, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 1, 0, 0", + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 2, 0", + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-2954361355555045376, 0, 0, 0", // here, dispose is earlier, but it was late to deliver + // error signal in the drainLoop + "-7566047373982433280, 0, 0, 0", + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = { + "-3242591731706757120, 4, 2, 0", + "-4107282860161892352, 4, 2, 0", + "-4395513236313604096, 4, 2, 0", + "-4539628424389459968, 4, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 4, 0, 0", + "-4539628424389459968, 4, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 3, 2, 0", + "-4107282860161892352, 3, 2, 0", + "-4395513236313604096, 3, 2, 0", + "-4539628424389459968, 3, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 3, 0, 0", + "-4539628424389459968, 3, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 2, 2, 0", + "-4107282860161892352, 2, 2, 0", + "-4395513236313604096, 2, 2, 0", + "-4539628424389459968, 2, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 2, 0, 0", + "-4539628424389459968, 2, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1, buf2) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 1, 2, 0", + "-4107282860161892352, 1, 2, 0", + "-4395513236313604096, 1, 2, 0", + "-4539628424389459968, 1, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 1, 0, 0", + "-4539628424389459968, 1, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @Outcome( + id = { + "-3242591731706757120, 0, 2, 0", + "-4107282860161892352, 0, 2, 0", + "-4395513236313604096, 0, 2, 0", + "-4539628424389459968, 0, 2, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = { + "-7854277750134145024, 0, 0, 0", + "-4539628424389459968, 0, 0, 0", + }, + expect = Expect.ACCEPTABLE, + desc = "next(buf1) -> cancel() before anything") + @State + public static class SubscribeWithFollowingRequestsVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndRequest() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + stressSubscriber.request(1); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.NONE); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"-4539628424389459968, 0, 2, 0", "-3386706919782612992, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "dispose() before anything") + @Outcome( + id = {"-4395513236313604096, 0, 2, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> dispose() before anything") + @Outcome( + id = {"-3242591731706757120, 0, 2, 0", "-3242591731706757120, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> (dispose() || cancel())") + @Outcome( + id = {"-7854277750134145024, 0, 0, 0"}, + expect = Expect.ACCEPTABLE, + desc = "subscribe() -> cancel() before anything") + @State + public static class SubscribeWithFollowingCancelVsOnNextVsDisposeFusedStressTest + extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber = new StressSubscriber<>(0, Fuseable.ANY); + final ByteBuf byteBuf1 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(1); + final ByteBuf byteBuf2 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(2); + final ByteBuf byteBuf3 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(3); + final ByteBuf byteBuf4 = UnpooledByteBufAllocator.DEFAULT.buffer().writeByte(4); + + @Actor + public void subscribeAndCancel() { + unboundedProcessor.subscribe(stressSubscriber); + stressSubscriber.cancel(); + } + + @Actor + public void dispose() { + unboundedProcessor.dispose(); + } + + @Actor + public void next1() { + unboundedProcessor.onNext(byteBuf1); + unboundedProcessor.onNextPrioritized(byteBuf2); + } + + @Actor + public void next2() { + unboundedProcessor.onNextPrioritized(byteBuf3); + unboundedProcessor.onNext(byteBuf4); + } + + @Arbiter + public void arbiter(LLLL_Result r) { + r.r1 = unboundedProcessor.state; + r.r2 = stressSubscriber.onNextCalls; + r.r3 = + stressSubscriber.onCompleteCalls + + stressSubscriber.onErrorCalls * 2 + + stressSubscriber.droppedErrors.size() * 3; + + stressSubscriber.values.forEach(ByteBuf::release); + + r.r4 = byteBuf1.refCnt() + byteBuf2.refCnt() + byteBuf3.refCnt() + byteBuf4.refCnt(); + + checkOutcomes(this, r.toString(), logger); + } + } + + @JCStressTest + @Outcome( + id = {"1"}, + expect = Expect.ACCEPTABLE) + @State + public static class SubscribeVsSubscribeStressTest extends UnboundedProcessorStressTest { + + final StressSubscriber stressSubscriber1 = new StressSubscriber<>(0, Fuseable.NONE); + final StressSubscriber stressSubscriber2 = new StressSubscriber<>(0, Fuseable.NONE); + + @Actor + public void subscribe1() { + unboundedProcessor.subscribe(stressSubscriber1); + } + + @Actor + public void subscribe2() { + unboundedProcessor.subscribe(stressSubscriber2); + } + + @Arbiter + public void arbiter(L_Result r) { + r.r1 = stressSubscriber1.onErrorCalls + stressSubscriber2.onErrorCalls; + + checkOutcomes(this, r.toString(), logger); + } + } + + static void checkOutcomes(Object instance, String result, Logger logger) { + if (Arrays.stream(instance.getClass().getDeclaredAnnotationsByType(Outcome.class)) + .flatMap(o -> Arrays.stream(o.id())) + .noneMatch(s -> s.equalsIgnoreCase(result))) { + throw new RuntimeException(result + " " + logger); + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java new file mode 100644 index 000000000..f0b209552 --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/resume/InMemoryResumableFramesStoreStressTest.java @@ -0,0 +1,118 @@ +package io.rsocket.resume; + +import static org.openjdk.jcstress.annotations.Expect.ACCEPTABLE; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import org.openjdk.jcstress.annotations.Actor; +import org.openjdk.jcstress.annotations.Arbiter; +import org.openjdk.jcstress.annotations.JCStressTest; +import org.openjdk.jcstress.annotations.Outcome; +import org.openjdk.jcstress.annotations.State; +import org.openjdk.jcstress.infra.results.LL_Result; +import reactor.core.Disposable; + +public class InMemoryResumableFramesStoreStressTest { + boolean storeClosed; + + InMemoryResumableFramesStore store = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 128); + boolean processorClosed; + UnboundedProcessor processor = new UnboundedProcessor(() -> processorClosed = true); + + void subscribe() { + store.saveFrames(processor).subscribe(); + store.onClose().subscribe(null, t -> storeClosed = true, () -> storeClosed = true); + } + + @JCStressTest + @Outcome( + id = {"true, true"}, + expect = ACCEPTABLE) + @State + public static class TwoSubscribesRaceStressTest extends InMemoryResumableFramesStoreStressTest { + + Disposable d1; + + final ByteBuf b1 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello1"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello2")); + final ByteBuf b2 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 3, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello3"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello4")); + final ByteBuf b3 = + PayloadFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 5, + false, + true, + false, + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello5"), + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "hello6")); + + final ByteBuf c1 = + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new ConnectionErrorException("closed")); + + { + subscribe(); + d1 = store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer1() { + processor.tryEmitNormal(b1); + processor.tryEmitNormal(b2); + processor.tryEmitNormal(b3); + } + + @Actor + public void producer2() { + processor.tryEmitFinal(c1); + } + + @Actor + public void producer3() { + d1.dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store + .doOnDiscard(ByteBuf.class, ByteBuf::release) + .subscribe(ByteBuf::release, t -> {}) + .dispose(); + store.doOnDiscard(ByteBuf.class, ByteBuf::release).subscribe(ByteBuf::release, t -> {}); + } + + @Actor + public void producer4() { + store.releaseFrames(0); + store.releaseFrames(0); + store.releaseFrames(0); + } + + @Arbiter + public void arbiter(LL_Result r) { + r.r1 = storeClosed; + r.r2 = processorClosed; + } + } +} diff --git a/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java new file mode 100644 index 000000000..c301d87cf --- /dev/null +++ b/rsocket-core/src/jcstress/java/io/rsocket/utils/FastLogger.java @@ -0,0 +1,137 @@ +package io.rsocket.utils; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; +import reactor.util.Logger; + +/** + * Implementation of {@link Logger} which is based on the {@link ThreadLocal} based queue which + * collects all the events on the per-thread basis.
Such logger is designed to have all events + * stored during the stress-test run and then sorted and printed out once all the Threads completed + * execution (inside the {@link org.openjdk.jcstress.annotations.Arbiter} annotated method.
+ * Note, this implementation only supports trace-level logs and ignores all others, it is intended + * to be used by {@link reactor.core.publisher.StateLogger}. + */ +public class FastLogger implements Logger { + + final Map> queues = new ConcurrentHashMap<>(); + + final ThreadLocal> logsQueueLocal = + ThreadLocal.withInitial( + () -> { + final ArrayList logs = new ArrayList<>(100); + queues.put(Thread.currentThread(), logs); + return logs; + }); + + private final String name; + + public FastLogger(String name) { + this.name = name; + } + + @Override + public String toString() { + return queues + .values() + .stream() + .flatMap(List::stream) + .sorted( + Comparator.comparingLong( + s -> { + Pattern pattern = Pattern.compile("\\[(.*?)]"); + Matcher matcher = pattern.matcher(s); + matcher.find(); + return Long.parseLong(matcher.group(1)); + })) + .collect(Collectors.joining("\n")); + } + + @Override + public String getName() { + return this.name; + } + + @Override + public boolean isTraceEnabled() { + return true; + } + + @Override + public void trace(String msg) { + logsQueueLocal.get().add(String.format("[%s] %s", System.nanoTime(), msg)); + } + + @Override + public void trace(String format, Object... arguments) { + trace(String.format(format, arguments)); + } + + @Override + public void trace(String msg, Throwable t) { + trace(String.format("%s, %s", msg, Arrays.toString(t.getStackTrace()))); + } + + @Override + public boolean isDebugEnabled() { + return false; + } + + @Override + public void debug(String msg) {} + + @Override + public void debug(String format, Object... arguments) {} + + @Override + public void debug(String msg, Throwable t) {} + + @Override + public boolean isInfoEnabled() { + return false; + } + + @Override + public void info(String msg) {} + + @Override + public void info(String format, Object... arguments) {} + + @Override + public void info(String msg, Throwable t) {} + + @Override + public boolean isWarnEnabled() { + return false; + } + + @Override + public void warn(String msg) {} + + @Override + public void warn(String format, Object... arguments) {} + + @Override + public void warn(String msg, Throwable t) {} + + @Override + public boolean isErrorEnabled() { + return false; + } + + @Override + public void error(String msg) {} + + @Override + public void error(String format, Object... arguments) {} + + @Override + public void error(String msg, Throwable t) {} +} diff --git a/rsocket-core/src/jcstress/resources/logback.xml b/rsocket-core/src/jcstress/resources/logback.xml new file mode 100644 index 000000000..e5877552c --- /dev/null +++ b/rsocket-core/src/jcstress/resources/logback.xml @@ -0,0 +1,39 @@ + + + + + + + + %d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/rsocket-core/src/main/java/io/rsocket/Availability.java b/rsocket-core/src/main/java/io/rsocket/Availability.java new file mode 100644 index 000000000..3361bcf8d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/Availability.java @@ -0,0 +1,26 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +public interface Availability { + + /** + * @return a positive numbers representing the availability of the entity. Higher is better, 0.0 + * means not available + */ + double availability(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/Closeable.java b/rsocket-core/src/main/java/io/rsocket/Closeable.java new file mode 100644 index 000000000..2ea9a0371 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/Closeable.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import org.reactivestreams.Subscriber; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; + +/** An interface which allows listening to when a specific instance of this interface is closed */ +public interface Closeable extends Disposable { + /** + * Returns a {@link Mono} that terminates when the instance is terminated by any reason. Note, in + * case of error termination, the cause of error will be propagated as an error signal through + * {@link org.reactivestreams.Subscriber#onError(Throwable)}. Otherwise, {@link + * Subscriber#onComplete()} will be called. + * + * @return a {@link Mono} to track completion with success or error of the underlying resource. + * When the underlying resource is an `RSocket`, the {@code Mono} exposes stream 0 (i.e. + * connection level) errors. + */ + Mono onClose(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java new file mode 100644 index 000000000..c39e679a1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/ConnectionSetupPayload.java @@ -0,0 +1,59 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import io.netty.buffer.ByteBuf; +import io.netty.util.AbstractReferenceCounted; +import reactor.util.annotation.Nullable; + +/** + * Exposes information from the {@code SETUP} frame to a server, as well as to client responders. + */ +public abstract class ConnectionSetupPayload extends AbstractReferenceCounted implements Payload { + + public abstract String metadataMimeType(); + + public abstract String dataMimeType(); + + public abstract int keepAliveInterval(); + + public abstract int keepAliveMaxLifetime(); + + public abstract int getFlags(); + + public abstract boolean willClientHonorLease(); + + public abstract boolean isResumeEnabled(); + + @Nullable + public abstract ByteBuf resumeToken(); + + @Override + public ConnectionSetupPayload retain() { + super.retain(); + return this; + } + + @Override + public ConnectionSetupPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public abstract ConnectionSetupPayload touch(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java new file mode 100644 index 000000000..fe91f4bf0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/DuplexConnection.java @@ -0,0 +1,93 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import org.reactivestreams.Subscriber; +import reactor.core.publisher.Flux; + +/** Represents a connection with input/output that the protocol uses. */ +public interface DuplexConnection extends Availability, Closeable { + + /** + * Delivers the given frame to the underlying transport connection. This method is non-blocking + * and can be safely executed from multiple threads. This method does not provide any flow-control + * mechanism. + * + * @param streamId to which the given frame relates + * @param frame with the encoded content + */ + void sendFrame(int streamId, ByteBuf frame); + + /** + * Send an error frame and after it is successfully sent, close the connection. + * + * @param errorException to encode in the error frame + */ + void sendErrorAndClose(RSocketErrorException errorException); + + /** + * Returns a stream of all {@code Frame}s received on this connection. + * + *

Completion + * + *

Returned {@code Publisher} MUST never emit a completion event ({@link + * Subscriber#onComplete()}). + * + *

Error + * + *

Returned {@code Publisher} can error with various transport errors. If the underlying + * physical connection is closed by the peer, then the returned stream from here MUST + * emit an {@link ClosedChannelException}. + * + *

Multiple Subscriptions + * + *

Returned {@code Publisher} is not required to support multiple concurrent subscriptions. + * RSocket will never have multiple subscriptions to this source. Implementations MUST + * emit an {@link IllegalStateException} for subsequent concurrent subscriptions, if they do not + * support multiple concurrent subscriptions. + * + * @return Stream of all {@code Frame}s received. + */ + Flux receive(); + + /** + * Returns the assigned {@link ByteBufAllocator}. + * + * @return the {@link ByteBufAllocator} + */ + ByteBufAllocator alloc(); + + /** + * Return the remote address that this connection is connected to. The returned {@link + * SocketAddress} varies by transport type and should be downcast to obtain more detailed + * information. For TCP and WebSocket, the address type is {@link java.net.InetSocketAddress}. For + * local transport, it is {@link io.rsocket.transport.local.LocalSocketAddress}. + * + * @return the address + * @since 1.1 + */ + SocketAddress remoteAddress(); + + @Override + default double availability() { + return isDisposed() ? 0.0 : 1.0; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/Payload.java b/rsocket-core/src/main/java/io/rsocket/Payload.java new file mode 100644 index 000000000..fc130528e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/Payload.java @@ -0,0 +1,104 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import io.netty.buffer.ByteBuf; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +/** Payload of a Frame . */ +public interface Payload extends ReferenceCounted { + /** + * Returns whether the payload has metadata, useful for tell if metadata is empty or not present. + * + * @return whether payload has non-null (possibly empty) metadata + */ + boolean hasMetadata(); + + /** + * Returns a slice Payload metadata. Always non-null, check {@link #hasMetadata()} to + * differentiate null from "". + * + * @return payload metadata. + */ + ByteBuf sliceMetadata(); + + /** + * Returns the Payload data. Always non-null. + * + * @return payload data. + */ + ByteBuf sliceData(); + + /** + * Returns the Payloads' data without slicing if possible. This is not safe and editing this could + * effect the payload. It is recommended to call sliceData(). + * + * @return data as a bytebuf or slice of the data + */ + ByteBuf data(); + + /** + * Returns the Payloads' metadata without slicing if possible. This is not safe and editing this + * could effect the payload. It is recommended to call sliceMetadata(). + * + * @return metadata as a bytebuf or slice of the metadata + */ + ByteBuf metadata(); + + /** Increases the reference count by {@code 1}. */ + @Override + Payload retain(); + + /** Increases the reference count by the specified {@code increment}. */ + @Override + Payload retain(int increment); + + /** + * Records the current access location of this object for debugging purposes. If this object is + * determined to be leaked, the information recorded by this operation will be provided to you via + * {@link ResourceLeakDetector}. This method is a shortcut to {@link #touch(Object) touch(null)}. + */ + @Override + Payload touch(); + + /** + * Records the current access location of this object with an additional arbitrary information for + * debugging purposes. If this object is determined to be leaked, the information recorded by this + * operation will be provided to you via {@link ResourceLeakDetector}. + */ + @Override + Payload touch(Object hint); + + default ByteBuffer getMetadata() { + return sliceMetadata().nioBuffer(); + } + + default ByteBuffer getData() { + return sliceData().nioBuffer(); + } + + default String getMetadataUtf8() { + return sliceMetadata().toString(StandardCharsets.UTF_8); + } + + default String getDataUtf8() { + return sliceData().toString(StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocket.java b/rsocket-core/src/main/java/io/rsocket/RSocket.java new file mode 100644 index 000000000..b05241365 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocket.java @@ -0,0 +1,99 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * A contract providing different interaction models for RSocket protocol. + */ +public interface RSocket extends Availability, Closeable { + + /** + * Fire and Forget interaction model of {@code RSocket}. + * + * @param payload Request payload. + * @return {@code Publisher} that completes when the passed {@code payload} is successfully + * handled, otherwise errors. + */ + default Mono fireAndForget(Payload payload) { + return RSocketAdapter.fireAndForget(payload); + } + + /** + * Request-Response interaction model of {@code RSocket}. + * + * @param payload Request payload. + * @return {@code Publisher} containing at most a single {@code Payload} representing the + * response. + */ + default Mono requestResponse(Payload payload) { + return RSocketAdapter.requestResponse(payload); + } + + /** + * Request-Stream interaction model of {@code RSocket}. + * + * @param payload Request payload. + * @return {@code Publisher} containing the stream of {@code Payload}s representing the response. + */ + default Flux requestStream(Payload payload) { + return RSocketAdapter.requestStream(payload); + } + + /** + * Request-Channel interaction model of {@code RSocket}. + * + * @param payloads Stream of request payloads. + * @return Stream of response payloads. + */ + default Flux requestChannel(Publisher payloads) { + return RSocketAdapter.requestChannel(payloads); + } + + /** + * Metadata-Push interaction model of {@code RSocket}. + * + * @param payload Request payloads. + * @return {@code Publisher} that completes when the passed {@code payload} is successfully + * handled, otherwise errors. + */ + default Mono metadataPush(Payload payload) { + return RSocketAdapter.metadataPush(payload); + } + + @Override + default double availability() { + return isDisposed() ? 0.0 : 1.0; + } + + @Override + default void dispose() {} + + @Override + default boolean isDisposed() { + return false; + } + + @Override + default Mono onClose() { + return Mono.never(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java new file mode 100644 index 000000000..b5a64b8dd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketAdapter.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Package private class with default implementations for use in {@link RSocket}. The main purpose + * is to hide static {@link UnsupportedOperationException} declarations. + * + * @since 1.0.3 + */ +class RSocketAdapter { + + private static final Mono UNSUPPORTED_FIRE_AND_FORGET = + Mono.error(new UnsupportedInteractionException("Fire-and-Forget")); + + private static final Mono UNSUPPORTED_REQUEST_RESPONSE = + Mono.error(new UnsupportedInteractionException("Request-Response")); + + private static final Flux UNSUPPORTED_REQUEST_STREAM = + Flux.error(new UnsupportedInteractionException("Request-Stream")); + + private static final Flux UNSUPPORTED_REQUEST_CHANNEL = + Flux.error(new UnsupportedInteractionException("Request-Channel")); + + private static final Mono UNSUPPORTED_METADATA_PUSH = + Mono.error(new UnsupportedInteractionException("Metadata-Push")); + + static Mono fireAndForget(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_FIRE_AND_FORGET; + } + + static Mono requestResponse(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_RESPONSE; + } + + static Flux requestStream(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_REQUEST_STREAM; + } + + static Flux requestChannel(Publisher payloads) { + return RSocketAdapter.UNSUPPORTED_REQUEST_CHANNEL; + } + + static Mono metadataPush(Payload payload) { + payload.release(); + return RSocketAdapter.UNSUPPORTED_METADATA_PUSH; + } + + private static class UnsupportedInteractionException extends RuntimeException { + + private static final long serialVersionUID = 5084623297446471999L; + + UnsupportedInteractionException(String interactionName) { + super(interactionName + " not implemented.", null, false, false); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java new file mode 100644 index 000000000..b43b14bae --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/RSocketErrorException.java @@ -0,0 +1,82 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import reactor.util.annotation.Nullable; + +/** + * Exception that represents an RSocket protocol error. + * + * @see ERROR + * Frame (0x0B) + */ +public class RSocketErrorException extends RuntimeException { + + private static final long serialVersionUID = -1628781753426267554L; + + private static final int MIN_ERROR_CODE = 0x00000001; + + private static final int MAX_ERROR_CODE = 0xFFFFFFFE; + + private final int errorCode; + + /** + * Constructor with a protocol error code and a message. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + */ + public RSocketErrorException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Alternative to {@link #RSocketErrorException(int, String)} with a root cause. + * + * @param errorCode the RSocket protocol error code + * @param message error explanation + * @param cause a root cause for the error + */ + public RSocketErrorException(int errorCode, String message, @Nullable Throwable cause) { + super(message, cause); + this.errorCode = errorCode; + if (errorCode > MAX_ERROR_CODE && errorCode < MIN_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000001-0xFFFFFFFE]", this); + } + } + + /** + * Return the RSocket error code + * represented by this exception + * + * @return the RSocket protocol error code + */ + public int errorCode() { + return errorCode; + } + + @Override + public String toString() { + return getClass().getSimpleName() + + " (0x" + + Integer.toHexString(errorCode) + + "): " + + getMessage(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java new file mode 100644 index 000000000..a42626e78 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/SocketAcceptor.java @@ -0,0 +1,93 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +import io.rsocket.exceptions.SetupException; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * RSocket is a full duplex protocol where a client and server are identical in terms of both having + * the capability to initiate requests to their peer. This interface provides the contract where a + * client or server handles the {@code setup} for a new connection and creates a responder {@code + * RSocket} for accepting requests from the remote peer. + */ +public interface SocketAcceptor { + + /** + * Handle the {@code SETUP} frame for a new connection and create a responder {@code RSocket} for + * handling requests from the remote peer. + * + * @param setup the {@code setup} received from a client in a server scenario, or in a client + * scenario this is the setup about to be sent to the server. + * @param sendingSocket socket for sending requests to the remote peer. + * @return {@code RSocket} to accept requests with. + * @throws SetupException If the acceptor needs to reject the setup of this socket. + */ + Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket); + + /** Create a {@code SocketAcceptor} that handles requests with the given {@code RSocket}. */ + static SocketAcceptor with(RSocket rsocket) { + return (setup, sendingRSocket) -> Mono.just(rsocket); + } + + /** Create a {@code SocketAcceptor} for fire-and-forget interactions with the given handler. */ + static SocketAcceptor forFireAndForget(Function> handler) { + return with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-response interactions with the given handler. */ + static SocketAcceptor forRequestResponse(Function> handler) { + return with( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-stream interactions with the given handler. */ + static SocketAcceptor forRequestStream(Function> handler) { + return with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return handler.apply(payload); + } + }); + } + + /** Create a {@code SocketAcceptor} for request-channel interactions with the given handler. */ + static SocketAcceptor forRequestChannel(Function, Flux> handler) { + return with( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return handler.apply(payloads); + } + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java new file mode 100644 index 000000000..e19d31924 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientServerInputMultiplexer.java @@ -0,0 +1,348 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; + +/** + * {@link DuplexConnection#receive()} is a single stream on which the following type of frames + * arrive: + * + *

    + *
  • Frames for streams initiated by the initiator of the connection (client). + *
  • Frames for streams initiated by the acceptor of the connection (server). + *
+ * + *

The only way to differentiate these two frames is determining whether the stream Id is odd or + * even. Even IDs are for the streams initiated by server and odds are for streams initiated by the + * client. + */ +class ClientServerInputMultiplexer implements CoreSubscriber, Closeable { + + private final InternalDuplexConnection serverReceiver; + private final InternalDuplexConnection clientReceiver; + private final DuplexConnection serverConnection; + private final DuplexConnection clientConnection; + private final DuplexConnection source; + private final boolean isClient; + + private Subscription s; + + private Throwable t; + + private volatile int state; + private static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ClientServerInputMultiplexer.class, "state"); + + public ClientServerInputMultiplexer( + DuplexConnection source, InitializingInterceptorRegistry registry, boolean isClient) { + this.source = source; + this.isClient = isClient; + + this.serverReceiver = new InternalDuplexConnection(Type.SERVER, this, source); + this.clientReceiver = new InternalDuplexConnection(Type.CLIENT, this, source); + this.serverConnection = registry.initConnection(Type.SERVER, serverReceiver); + this.clientConnection = registry.initConnection(Type.CLIENT, clientReceiver); + } + + DuplexConnection asServerConnection() { + return serverConnection; + } + + DuplexConnection asClientConnection() { + return clientConnection; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(ByteBuf frame) { + int streamId = FrameHeaderCodec.streamId(frame); + final Type type; + if (streamId == 0) { + switch (FrameHeaderCodec.frameType(frame)) { + case LEASE: + case KEEPALIVE: + case ERROR: + type = isClient ? Type.CLIENT : Type.SERVER; + break; + default: + type = isClient ? Type.SERVER : Type.CLIENT; + } + } else if ((streamId & 0b1) == 0) { + type = Type.SERVER; + } else { + type = Type.CLIENT; + } + + switch (type) { + case CLIENT: + clientReceiver.onNext(frame); + break; + case SERVER: + serverReceiver.onNext(frame); + break; + } + } + + @Override + public void onComplete() { + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onComplete(); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + this.t = t; + + final int previousState = STATE.getAndSet(this, Integer.MIN_VALUE); + if (previousState == Integer.MIN_VALUE || previousState == 0) { + return; + } + + if (clientReceiver.isSubscribed()) { + clientReceiver.onError(t); + } + if (serverReceiver.isSubscribed()) { + serverReceiver.onError(t); + } + } + + boolean notifyRequested() { + final int currentState = incrementAndGetCheckingState(); + if (currentState == Integer.MIN_VALUE) { + return false; + } + + if (currentState == 2) { + source.receive().subscribe(this); + } + + return true; + } + + int incrementAndGetCheckingState() { + int prev, next; + for (; ; ) { + prev = this.state; + + if (prev == Integer.MIN_VALUE) { + return prev; + } + + next = prev + 1; + if (STATE.compareAndSet(this, prev, next)) { + return next; + } + } + } + + @Override + public String toString() { + return "ClientServerInputMultiplexer{" + + "serverReceiver=" + + serverReceiver + + ", clientReceiver=" + + clientReceiver + + ", serverConnection=" + + serverConnection + + ", clientConnection=" + + clientConnection + + ", source=" + + source + + ", isClient=" + + isClient + + ", s=" + + s + + ", t=" + + t + + ", state=" + + state + + '}'; + } + + private static class InternalDuplexConnection extends Flux + implements Subscription, DuplexConnection { + private final Type type; + private final ClientServerInputMultiplexer clientServerInputMultiplexer; + private final DuplexConnection source; + + private volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(InternalDuplexConnection.class, "state"); + + CoreSubscriber actual; + + public InternalDuplexConnection( + Type type, + ClientServerInputMultiplexer clientServerInputMultiplexer, + DuplexConnection source) { + this.type = type; + this.clientServerInputMultiplexer = clientServerInputMultiplexer; + this.source = source; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (this.state == 0 && STATE.compareAndSet(this, 0, 1)) { + this.actual = actual; + actual.onSubscribe(this); + } else { + Operators.error( + actual, + new IllegalStateException("InternalDuplexConnection allows only single subscription")); + } + } + + @Override + public void request(long n) { + if (this.state == 1 && STATE.compareAndSet(this, 1, 2)) { + final ClientServerInputMultiplexer multiplexer = clientServerInputMultiplexer; + if (!multiplexer.notifyRequested()) { + final Throwable t = multiplexer.t; + if (t != null) { + this.actual.onError(t); + } else { + this.actual.onComplete(); + } + } + } + } + + @Override + public void cancel() { + // no ops + } + + void onNext(ByteBuf frame) { + this.actual.onNext(frame); + } + + void onComplete() { + this.actual.onComplete(); + } + + void onError(Throwable t) { + this.actual.onError(t); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + public boolean isSubscribed() { + return this.state != 0; + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public double availability() { + return source.availability(); + } + + @Override + public String toString() { + return "InternalDuplexConnection{" + + "type=" + + type + + ", source=" + + source + + ", state=" + + state + + '}'; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java new file mode 100644 index 000000000..3477b8d6d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ClientSetup.java @@ -0,0 +1,49 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.DuplexConnection; +import java.nio.channels.ClosedChannelException; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +abstract class ClientSetup { + abstract Mono> init(DuplexConnection connection); +} + +class DefaultClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> sink.onRequest(__ -> sink.success(Tuples.of(Unpooled.EMPTY_BUFFER, connection)))); + } +} + +class ResumableClientSetup extends ClientSetup { + + @Override + Mono> init(DuplexConnection connection) { + return Mono.create( + sink -> { + sink.onRequest( + __ -> { + new SetupHandlingDuplexConnection(connection, sink); + }); + + Disposable subscribe = + connection + .onClose() + .doFinally(__ -> sink.error(new ClosedChannelException())) + .subscribe(); + sink.onCancel( + () -> { + subscribe.dispose(); + connection.dispose(); + connection.receive().subscribe(); + }); + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java new file mode 100644 index 000000000..9b5647c6f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultConnectionSetupPayload.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; + +/** + * Default implementation of {@link ConnectionSetupPayload}. Primarily for internal use within + * RSocket Java but may be created in an application, e.g. for testing purposes. + */ +public class DefaultConnectionSetupPayload extends ConnectionSetupPayload { + + private final ByteBuf setupFrame; + + public DefaultConnectionSetupPayload(ByteBuf setupFrame) { + this.setupFrame = setupFrame; + } + + @Override + public boolean hasMetadata() { + return FrameHeaderCodec.hasMetadata(setupFrame); + } + + @Override + public ByteBuf sliceMetadata() { + final ByteBuf metadata = SetupFrameCodec.metadata(setupFrame); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + return SetupFrameCodec.data(setupFrame); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public String metadataMimeType() { + return SetupFrameCodec.metadataMimeType(setupFrame); + } + + @Override + public String dataMimeType() { + return SetupFrameCodec.dataMimeType(setupFrame); + } + + @Override + public int keepAliveInterval() { + return SetupFrameCodec.keepAliveInterval(setupFrame); + } + + @Override + public int keepAliveMaxLifetime() { + return SetupFrameCodec.keepAliveMaxLifetime(setupFrame); + } + + @Override + public int getFlags() { + return FrameHeaderCodec.flags(setupFrame); + } + + @Override + public boolean willClientHonorLease() { + return SetupFrameCodec.honorLease(setupFrame); + } + + @Override + public boolean isResumeEnabled() { + return SetupFrameCodec.resumeEnabled(setupFrame); + } + + @Override + public ByteBuf resumeToken() { + return SetupFrameCodec.resumeToken(setupFrame); + } + + @Override + public ConnectionSetupPayload touch() { + setupFrame.touch(); + return this; + } + + @Override + public ConnectionSetupPayload touch(Object hint) { + setupFrame.touch(hint); + return this; + } + + @Override + protected void deallocate() { + setupFrame.release(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java new file mode 100644 index 000000000..82a02268d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/DefaultRSocketClient.java @@ -0,0 +1,562 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import java.util.AbstractMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CorePublisher; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoOperator; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * Default implementation of {@link RSocketClient} + * + * @since 1.0.1 + */ +class DefaultRSocketClient extends ResolvingOperator + implements CoreSubscriber, CorePublisher, RSocketClient { + static final Consumer DISCARD_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + ReferenceCounted referenceCounted = ((ReferenceCounted) data); + if (referenceCounted.refCnt() > 0) { + try { + referenceCounted.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + } + }; + + static final Object ON_DISCARD_KEY; + + static { + Context discardAwareContext = Operators.enableOnDiscard(null, DISCARD_ELEMENTS_CONSUMER); + ON_DISCARD_KEY = discardAwareContext.stream().findFirst().get().getKey(); + } + + final Mono source; + + final Sinks.Empty onDisposeSink; + + volatile Subscription s; + + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(DefaultRSocketClient.class, Subscription.class, "s"); + + DefaultRSocketClient(Mono source) { + this.source = unwrapReconnectMono(source); + this.onDisposeSink = Sinks.empty(); + } + + private Mono unwrapReconnectMono(Mono source) { + return source instanceof ReconnectMono ? ((ReconnectMono) source).getSource() : source; + } + + @Override + public Mono onClose() { + return this.onDisposeSink.asMono(); + } + + @Override + public Mono source() { + return Mono.fromDirect(this); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return new RSocketClientMonoOperator<>(this, FrameType.REQUEST_FNF, payloadMono); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return new RSocketClientMonoOperator<>(this, FrameType.REQUEST_RESPONSE, payloadMono); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return new RSocketClientFluxOperator<>(this, FrameType.REQUEST_STREAM, payloadMono); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new RSocketClientFluxOperator<>(this, FrameType.REQUEST_CHANNEL, payloads); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return new RSocketClientMonoOperator<>(this, FrameType.METADATA_PUSH, payloadMono); + } + + @Override + @SuppressWarnings("uncheked") + public void subscribe(CoreSubscriber actual) { + final ResolvingOperator.MonoDeferredResolutionOperator inner = + new ResolvingOperator.MonoDeferredResolutionOperator<>(this, actual); + actual.onSubscribe(inner); + + this.observe(inner); + } + + @Override + public void subscribe(Subscriber s) { + subscribe(Operators.toCoreSubscriber(s)); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final RSocket value = this.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + this.doFinally(); + return; + } + + if (value == null) { + this.terminate(new IllegalStateException("Source completed empty")); + } else { + this.complete(value); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + this.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doFinally(); + // terminate upstream which means retryBackoff has exhausted + this.terminate(t); + } + + @Override + public void onNext(RSocket value) { + if (this.s == Operators.cancelledSubscription()) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + // volatile write and check on racing + this.doFinally(); + } + + @Override + protected void doSubscribe() { + this.source.subscribe(this); + } + + @Override + protected void doOnValueResolved(RSocket value) { + value.onClose().subscribe(null, t -> this.invalidate(), this::invalidate); + } + + @Override + protected void doOnValueExpired(RSocket value) { + value.dispose(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onDisposeSink::tryEmitError, onDisposeSink::tryEmitEmpty); + } else { + onDisposeSink.tryEmitEmpty(); + } + } + + static final class FlatMapMain implements CoreSubscriber, Context, Scannable { + + final DefaultRSocketClient parent; + final CoreSubscriber actual; + + final FlattingInner second; + + Subscription s; + + boolean done; + + FlatMapMain( + DefaultRSocketClient parent, CoreSubscriber actual, FrameType requestType) { + this.parent = parent; + this.actual = actual; + this.second = new FlattingInner<>(parent, this, actual, requestType); + } + + @Override + public Context currentContext() { + return this; + } + + @Override + public Stream inners() { + return Stream.of(this.second); + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.s; + if (key == Attr.CANCELLED) return this.second.isCancelled(); + if (key == Attr.TERMINATED) return this.done; + + return null; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + this.actual.onSubscribe(this.second); + } + } + + @Override + public void onNext(Payload payload) { + if (this.done) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + return; + } + this.done = true; + + final FlattingInner inner = this.second; + + if (inner.isCancelled()) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + return; + } + + inner.payload = payload; + + if (inner.isCancelled()) { + if (FlattingInner.PAYLOAD.compareAndSet(inner, payload, null)) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + } + return; + } + + this.parent.observe(inner); + } + + @Override + public void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + this.done = true; + + this.actual.onError(t); + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + this.done = true; + + this.actual.onComplete(); + } + + void request(long n) { + this.s.request(n); + } + + void cancel() { + this.s.cancel(); + } + + @Override + @SuppressWarnings("unchecked") + public K get(Object key) { + if (key == ON_DISCARD_KEY) { + return (K) DISCARD_ELEMENTS_CONSUMER; + } + return this.actual.currentContext().get(key); + } + + @Override + public boolean hasKey(Object key) { + if (key == ON_DISCARD_KEY) { + return true; + } + return this.actual.currentContext().hasKey(key); + } + + @Override + public Context put(Object key, Object value) { + return this.actual + .currentContext() + .put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER) + .put(key, value); + } + + @Override + public Context delete(Object key) { + return this.actual + .currentContext() + .put(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER) + .delete(key); + } + + @Override + public int size() { + return this.actual.currentContext().size() + 1; + } + + @Override + public Stream> stream() { + return Stream.concat( + Stream.of( + new AbstractMap.SimpleImmutableEntry<>(ON_DISCARD_KEY, DISCARD_ELEMENTS_CONSUMER)), + this.actual.currentContext().stream()); + } + } + + static final class FlattingInner extends DeferredResolution { + + final FlatMapMain main; + final FrameType interactionType; + + volatile Payload payload; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater PAYLOAD = + AtomicReferenceFieldUpdater.newUpdater(FlattingInner.class, Payload.class, "payload"); + + FlattingInner( + DefaultRSocketClient parent, + FlatMapMain main, + CoreSubscriber actual, + FrameType interactionType) { + super(parent, actual); + + this.main = main; + this.interactionType = interactionType; + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(RSocket rSocket, Throwable t) { + if (this.isCancelled()) { + return; + } + + Payload payload = PAYLOAD.getAndSet(this, null); + + // means cancelled + if (payload == null) { + return; + } + + if (t != null) { + if (payload.refCnt() > 0) { + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + // ignored + } + } + onError(t); + return; + } + + CorePublisher source; + switch (this.interactionType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(payload); + break; + case REQUEST_STREAM: + source = rSocket.requestStream(payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(payload); + break; + default: + this.onError(new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + } + + @Override + public void request(long n) { + super.request(n); + this.main.request(n); + } + + public void cancel() { + long state = REQUESTED.getAndSet(this, STATE_CANCELLED); + if (state == STATE_CANCELLED) { + return; + } + + this.main.cancel(); + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + Payload payload = PAYLOAD.getAndSet(this, null); + if (payload != null) { + payload.release(); + } + } + } + } + + static final class RequestChannelInner extends DeferredResolution { + + final FrameType interactionType; + final Publisher upstream; + + RequestChannelInner( + DefaultRSocketClient parent, + Publisher upstream, + CoreSubscriber actual, + FrameType interactionType) { + super(parent, actual); + + this.upstream = upstream; + this.interactionType = interactionType; + } + + @Override + public void accept(RSocket rSocket, Throwable t) { + if (this.isCancelled()) { + return; + } + + if (t != null) { + onError(t); + return; + } + + Flux source; + if (this.interactionType == FrameType.REQUEST_CHANNEL) { + source = rSocket.requestChannel(this.upstream); + } else { + this.onError(new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + } + } + + static class RSocketClientMonoOperator extends MonoOperator { + + final DefaultRSocketClient parent; + final FrameType requestType; + + public RSocketClientMonoOperator( + DefaultRSocketClient parent, FrameType requestType, Mono source) { + super(source); + this.parent = parent; + this.requestType = requestType; + } + + @Override + public void subscribe(CoreSubscriber actual) { + this.source.subscribe(new FlatMapMain(this.parent, actual, this.requestType)); + } + } + + static class RSocketClientFluxOperator> extends Flux { + + final DefaultRSocketClient parent; + final FrameType requestType; + final ST source; + + public RSocketClientFluxOperator( + DefaultRSocketClient parent, FrameType requestType, ST source) { + this.parent = parent; + this.requestType = requestType; + this.source = source; + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (requestType == FrameType.REQUEST_CHANNEL) { + RequestChannelInner inner = + new RequestChannelInner(this.parent, source, actual, requestType); + actual.onSubscribe(inner); + this.parent.observe(inner); + } else { + this.source.subscribe(new FlatMapMain<>(this.parent, actual, this.requestType)); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java new file mode 100644 index 000000000..a5d527f5c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetRequesterMono.java @@ -0,0 +1,295 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class FireAndForgetRequesterMono extends Mono implements Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(FireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; + + FireAndForgetRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + actual.onSubscribe(this); + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + actual.onError(e); + return; + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(ut); + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + if (isTerminated(this.state)) { + p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + markTerminated(STATE, this); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + throw e; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + throw e; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + throw Exceptions.propagate(e); + } + + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(Exceptions.unwrap(t), FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + throw Exceptions.propagate(t); + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_FNF, + this.mtu, + this.payload, + this.connection, + this.allocator, + true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + throw Exceptions.propagate(e); + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + return null; + } + + @Override + public Object scanUnsafe(Scannable.Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java new file mode 100644 index 000000000..e76fdf9ed --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FireAndForgetResponderSubscriber.java @@ -0,0 +1,183 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +final class FireAndForgetResponderSubscriber + implements CoreSubscriber, ResponderFrameHandler { + + static final Logger logger = LoggerFactory.getLogger(FireAndForgetResponderSubscriber.class); + + static final FireAndForgetResponderSubscriber INSTANCE = new FireAndForgetResponderSubscriber(); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final RequesterResponderSupport requesterResponderSupport; + final RSocket handler; + final int maxInboundPayloadSize; + + @Nullable final RequestInterceptor requestInterceptor; + + CompositeByteBuf frames; + + private FireAndForgetResponderSubscriber() { + this.streamId = 0; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = null; + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = null; + this.payloadDecoder = null; + this.maxInboundPayloadSize = 0; + this.requesterResponderSupport = null; + this.handler = null; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.frames = null; + } + + FireAndForgetResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.handler = handler; + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Dropped Outbound error", t); + } + + @Override + public void onComplete() { + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, null); + } + } + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + + try { + ReassemblyUtils.addFollowingFrame( + frames, followingFrame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException t) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + this.frames = null; + frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + if (!hasFollows) { + this.requesterResponderSupport.remove(this.streamId, this); + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(this.streamId, FrameType.REQUEST_FNF, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + Mono source = this.handler.fireAndForget(payload); + source.subscribe(this); + } + } + + @Override + public final void handleCancel() { + final CompositeByteBuf frames = this.frames; + if (frames != null) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + this.frames = null; + frames.release(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java new file mode 100644 index 000000000..03b6f9e09 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FragmentationUtils.java @@ -0,0 +1,224 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import reactor.util.annotation.Nullable; + +class FragmentationUtils { + + static final int MIN_MTU_SIZE = 64; + + static final int FRAME_OFFSET = // 9 bytes in total + FrameLengthCodec.FRAME_LENGTH_SIZE // includes encoded frame length bytes size + + FrameHeaderCodec.size(); // includes encoded frame headers info bytes size + static final int FRAME_OFFSET_WITH_METADATA = // 12 bytes in total + FRAME_OFFSET + + FrameLengthCodec.FRAME_LENGTH_SIZE; // include encoded metadata length bytes size + + static final int FRAME_OFFSET_WITH_INITIAL_REQUEST_N = // 13 bytes in total + FRAME_OFFSET + Integer.BYTES; // includes extra space for initialRequestN bytes size + static final int FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N = // 16 bytes in total + FRAME_OFFSET_WITH_METADATA + + Integer.BYTES; // includes extra space for initialRequestN bytes size + + static boolean isFragmentable( + int mtu, ByteBuf data, @Nullable ByteBuf metadata, boolean hasInitialRequestN) { + if (mtu == 0) { + return false; + } + + if (metadata != null) { + int remaining = + mtu + - (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA); + + return (metadata.readableBytes() + data.readableBytes()) > remaining; + } else { + int remaining = + mtu - (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET); + + return data.readableBytes() > remaining; + } + } + + static ByteBuf encodeFollowsFragment( + ByteBufAllocator allocator, + int mtu, + int streamId, + boolean complete, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = null; + if (metadata.isReadable()) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + boolean follows = data.isReadable() || metadata.isReadable(); + return PayloadFrameCodec.encode( + allocator, streamId, follows, (!follows && complete), true, metadataFragment, dataFragment); + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + FrameType frameType, + int streamId, + boolean hasMetadata, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length size + int remaining = mtu - FRAME_OFFSET; + + ByteBuf metadataFragment = hasMetadata ? Unpooled.EMPTY_BUFFER : null; + if (hasMetadata) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + if (metadata.isReadable()) { + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + case REQUEST_FNF: + return RequestFireAndForgetFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + case REQUEST_RESPONSE: + return RequestResponseFrameCodec.encode( + allocator, streamId, true, metadataFragment, dataFragment); + // Payload and synthetic types from the responder side + case PAYLOAD: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, false, metadataFragment, dataFragment); + case NEXT: + // see https://github.com/rsocket/rsocket/blob/master/Protocol.md#handling-the-unexpected + // point 7 + case NEXT_COMPLETE: + return PayloadFrameCodec.encode( + allocator, streamId, true, false, true, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static ByteBuf encodeFirstFragment( + ByteBufAllocator allocator, + int mtu, + long initialRequestN, + FrameType frameType, + int streamId, + boolean hasMetadata, + ByteBuf metadata, + ByteBuf data) { + // subtract the header bytes + frame length bytes + initial requestN bytes + int remaining = mtu - FRAME_OFFSET_WITH_INITIAL_REQUEST_N; + + ByteBuf metadataFragment = hasMetadata ? Unpooled.EMPTY_BUFFER : null; + if (hasMetadata) { + // subtract the metadata frame length + remaining -= FrameLengthCodec.FRAME_LENGTH_SIZE; + if (metadata.isReadable()) { + int r = Math.min(remaining, metadata.readableBytes()); + remaining -= r; + metadataFragment = metadata.readRetainedSlice(r); + } + } + + ByteBuf dataFragment = Unpooled.EMPTY_BUFFER; + try { + if (remaining > 0 && data.isReadable()) { + int r = Math.min(remaining, data.readableBytes()); + dataFragment = data.readRetainedSlice(r); + } + } catch (IllegalReferenceCountException | NullPointerException e) { + if (metadataFragment != null) { + metadataFragment.release(); + } + throw e; + } + + switch (frameType) { + // Requester Side + case REQUEST_STREAM: + return RequestStreamFrameCodec.encode( + allocator, streamId, true, initialRequestN, metadataFragment, dataFragment); + case REQUEST_CHANNEL: + return RequestChannelFrameCodec.encode( + allocator, streamId, true, false, initialRequestN, metadataFragment, dataFragment); + default: + throw new IllegalStateException("unsupported fragment type: " + frameType); + } + } + + static int assertMtu(int mtu) { + if (mtu > 0 && mtu < MIN_MTU_SIZE || mtu < 0) { + String msg = + String.format( + "The smallest allowed mtu size is %d bytes, provided: %d", MIN_MTU_SIZE, mtu); + throw new IllegalArgumentException(msg); + } else { + return mtu; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java new file mode 100644 index 000000000..6d1ee1b09 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/FrameHandler.java @@ -0,0 +1,31 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; + +interface FrameHandler { + + void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload); + + void handleError(Throwable t); + + void handleComplete(); + + void handleCancel(); + + void handleRequestN(long n); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java new file mode 100644 index 000000000..03ab7c257 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeasePermitHandler.java @@ -0,0 +1,20 @@ +package io.rsocket.core; + +/** Handler which enables async lease permits issuing */ +interface LeasePermitHandler { + + /** + * Called by {@link RequesterLeaseTracker} when there is an available lease + * + * @return {@code true} to indicate that lease permit was consumed successfully + */ + boolean handlePermit(); + + /** + * Called by {@link RequesterLeaseTracker} when there are no lease permit available at the moment + * and the list of awaiting {@link LeasePermitHandler} reached the configured limit + * + * @param t associated lease permit rejection exception + */ + void handlePermitError(Throwable t); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java new file mode 100644 index 000000000..ad4b36e3a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LeaseSpec.java @@ -0,0 +1,44 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.rsocket.lease.LeaseSender; +import reactor.core.publisher.Flux; + +public final class LeaseSpec { + + LeaseSender sender = Flux::never; + int maxPendingRequests = 256; + + LeaseSpec() {} + + public LeaseSpec sender(LeaseSender sender) { + this.sender = sender; + return this; + } + + /** + * Setup the maximum queued requests waiting for lease to be available. The default value is 256 + * + * @param maxPendingRequests if set to 0 the requester will terminate the request immediately if + * no leases is available + */ + public LeaseSpec maxPendingRequests(int maxPendingRequests) { + this.maxPendingRequests = maxPendingRequests; + return this; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java new file mode 100644 index 000000000..7b5d8f6c2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/LoggingDuplexConnection.java @@ -0,0 +1,72 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameUtil; +import java.net.SocketAddress; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +class LoggingDuplexConnection implements DuplexConnection { + + private static final Logger LOGGER = LoggerFactory.getLogger("io.rsocket.FrameLogger"); + + final DuplexConnection source; + + LoggingDuplexConnection(DuplexConnection source) { + this.source = source; + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + LOGGER.debug("sending -> " + FrameUtil.toString(frame)); + + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + LOGGER.debug("sending -> " + e.getClass().getSimpleName() + ": " + e.getMessage()); + + source.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return source + .receive() + .doOnNext(frame -> LOGGER.debug("receiving -> " + FrameUtil.toString(frame))); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + static DuplexConnection wrapIfEnabled(DuplexConnection source) { + if (LOGGER.isDebugEnabled()) { + return new LoggingDuplexConnection(source); + } + + return source; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java new file mode 100644 index 000000000..e2512e995 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushRequesterMono.java @@ -0,0 +1,190 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValidMetadata; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.MetadataPushFrameCodec; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class MetadataPushRequesterMono extends Mono implements Scannable { + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(MetadataPushRequesterMono.class, "state"); + + final ByteBufAllocator allocator; + final Payload payload; + final int maxFrameLength; + final DuplexConnection connection; + + MetadataPushRequesterMono(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.connection = requesterResponderSupport.getDuplexConnection(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + Operators.error( + actual, new IllegalStateException("MetadataPushMono allows only a single Subscriber")); + return; + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (!hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException("Metadata push should have metadata field present")); + return; + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + Operators.error( + actual, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength))); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + Operators.error(actual, e); + return; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + Operators.error(actual, e); + return; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.connection.sendFrame(0, requestFrame); + + Operators.complete(actual); + } + + @Override + @Nullable + public Void block(Duration m) { + return block(); + } + + /** + * This method is deliberately non-blocking regardless it is named as `.block`. The main intent to + * keep this method along with the {@link #subscribe()} is to eliminate redundancy which comes + * with a default block method implementation. + */ + @Override + @Nullable + public Void block() { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + throw new IllegalStateException("MetadataPushMono allows only a single Subscriber"); + } + + final Payload p = this.payload; + final ByteBuf metadata; + try { + final boolean hasMetadata = p.hasMetadata(); + metadata = p.metadata(); + if (!hasMetadata) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException("Metadata push should have metadata field present"); + } + if (!isValidMetadata(this.maxFrameLength, metadata)) { + lazyTerminate(STATE, this); + p.release(); + throw new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = metadata.retainedSlice(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + throw e; + } + + try { + p.release(); + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + metadataRetainedSlice.release(); + throw e; + } + + final ByteBuf requestFrame = + MetadataPushFrameCodec.encode(this.allocator, metadataRetainedSlice); + this.connection.sendFrame(0, requestFrame); + + return null; + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(MetadataPushMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java new file mode 100644 index 000000000..4c69934e8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/MetadataPushResponderSubscriber.java @@ -0,0 +1,45 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; + +final class MetadataPushResponderSubscriber implements CoreSubscriber { + static final Logger logger = LoggerFactory.getLogger(MetadataPushResponderSubscriber.class); + + static final MetadataPushResponderSubscriber INSTANCE = new MetadataPushResponderSubscriber(); + + private MetadataPushResponderSubscriber() {} + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void voidVal) {} + + @Override + public void onError(Throwable t) { + logger.debug("Dropped error", t); + } + + @Override + public void onComplete() {} +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java new file mode 100644 index 000000000..6ece319c9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/PayloadValidationUtils.java @@ -0,0 +1,76 @@ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_INITIAL_REQUEST_N; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; + +final class PayloadValidationUtils { + static final String INVALID_PAYLOAD_ERROR_MESSAGE = + "The payload is too big to be send as a single frame with a max frame length %s. Consider enabling fragmentation."; + + static boolean isValid(int mtu, int maxFrameLength, Payload payload, boolean hasInitialRequestN) { + + if (mtu > 0) { + return true; + } + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf data = payload.data(); + + int unitSize; + if (hasMetadata) { + final ByteBuf metadata = payload.metadata(); + unitSize = + (hasInitialRequestN + ? FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N + : FRAME_OFFSET_WITH_METADATA) + + metadata.readableBytes() + + // metadata payload bytes + data.readableBytes(); // data payload bytes + } else { + unitSize = + (hasInitialRequestN ? FRAME_OFFSET_WITH_INITIAL_REQUEST_N : FRAME_OFFSET) + + data.readableBytes(); // data payload bytes + } + + return unitSize <= maxFrameLength; + } + + static boolean isValidMetadata(int maxFrameLength, ByteBuf metadata) { + return FRAME_OFFSET + metadata.readableBytes() <= maxFrameLength; + } + + static void assertValidateSetup(int maxFrameLength, int maxInboundPayloadSize, int mtu) { + + if (maxFrameLength > FRAME_LENGTH_MASK) { + throw new IllegalArgumentException( + "Configured maxFrameLength[" + + maxFrameLength + + "] exceeds maxFrameLength limit " + + FRAME_LENGTH_MASK); + } + + if (maxFrameLength > maxInboundPayloadSize) { + throw new IllegalArgumentException( + "Configured maxFrameLength[" + + maxFrameLength + + "] exceeds maxPayloadSize[" + + maxInboundPayloadSize + + "]"); + } + + if (mtu != 0 && mtu > maxFrameLength) { + throw new IllegalArgumentException( + "Configured maximumTransmissionUnit[" + + mtu + + "] exceeds configured maxFrameLength[" + + maxFrameLength + + "]"); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java new file mode 100644 index 000000000..32e3c229d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClient.java @@ -0,0 +1,153 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import sun.reflect.generics.reflectiveObjects.NotImplementedException; + +/** + * Contract for performing RSocket requests. + * + *

{@link RSocketClient} differs from {@link RSocket} in a number of ways: + * + *

    + *
  • {@code RSocket} represents a "live" connection that is transient and needs to be obtained + * typically from a {@code Mono} source via {@code flatMap} or block. By contrast, + * {@code RSocketClient} is a higher level layer that contains such a {@link #source() source} + * of connections and transparently obtains and re-obtains a shared connection as needed when + * requests are made concurrently. That means an {@code RSocketClient} can simply be created + * once, even before a connection is established, and shared as a singleton across multiple + * places as you would with any other client. + *
  • For request input {@code RSocket} accepts an instance of {@code Payload} and does not allow + * more than one subscription per request because there is no way to safely re-use that input. + * By contrast {@code RSocketClient} accepts {@code Publisher} and allow + * re-subscribing which repeats the request. + *
  • {@code RSocket} can be used for sending and it can also be implemented for receiving. By + * contrast {@code RSocketClient} is used only for sending, typically from the client side + * which allows obtaining and re-obtaining connections from a source as needed. However it can + * also be used from the server side by {@link #from(RSocket) wrapping} the "live" {@code + * RSocket} for a given connection. + *
+ * + *

The example below shows how to create an {@code RSocketClient}: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ *
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + *

The below configures retry logic to use when a shared {@code RSocket} connection is obtained: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ *
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + * @since 1.1 + * @see io.rsocket.loadbalance.LoadbalanceRSocketClient + */ +public interface RSocketClient extends Closeable { + + /** + * Connect to the remote rsocket endpoint, if not yet connected. This method is a shortcut for + * {@code RSocketClient#source().subscribe()}. + * + * @return {@code true} if an attempt to connect was triggered or if already connected, or {@code + * false} if the client is terminated. + */ + default boolean connect() { + throw new NotImplementedException(); + } + + default Mono onClose() { + return Mono.error(new NotImplementedException()); + } + + /** Return the underlying source used to obtain a shared {@link RSocket} connection. */ + Mono source(); + + /** + * Perform a Fire-and-Forget interaction via {@link RSocket#fireAndForget(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Mono fireAndForget(Mono payloadMono); + + /** + * Perform a Request-Response interaction via {@link RSocket#requestResponse(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Mono requestResponse(Mono payloadMono); + + /** + * Perform a Request-Stream interaction via {@link RSocket#requestStream(Payload)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Flux requestStream(Mono payloadMono); + + /** + * Perform a Request-Channel interaction via {@link RSocket#requestChannel(Publisher)}. Allows + * multiple subscriptions and performs a request per subscriber. + */ + Flux requestChannel(Publisher payloads); + + /** + * Perform a Metadata Push via {@link RSocket#metadataPush(Payload)}. Allows multiple + * subscriptions and performs a request per subscriber. + */ + Mono metadataPush(Mono payloadMono); + + /** + * Create an {@link RSocketClient} that obtains shared connections as needed, when requests are + * made, from the given {@code Mono} source. + * + * @param source the source for connections, typically prepared via {@link RSocketConnector}. + * @return the created client instance + */ + static RSocketClient from(Mono source) { + return new DefaultRSocketClient(source); + } + + /** + * Adapt the given {@link RSocket} to use as {@link RSocketClient}. This is useful to wrap the + * sending {@code RSocket} in a server. + * + *

Note: unlike an {@code RSocketClient} created via {@link + * RSocketClient#from(Mono)}, the instance returned from this factory method can only perform + * requests for as long as the given {@code RSocket} remains "live". + * + * @param rsocket the {@code RSocket} to perform requests with + * @return the created client instance + */ + static RSocketClient from(RSocket rsocket) { + return new RSocketClientAdapter(rsocket); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java new file mode 100644 index 000000000..ae8b7da97 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketClientAdapter.java @@ -0,0 +1,88 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * Simple adapter from {@link RSocket} to {@link RSocketClient}. This is useful in code that needs + * to deal with both in the same way. When connecting to a server, typically {@link RSocketClient} + * is expected to be used, but in a responder (client or server), it is necessary to interact with + * {@link RSocket} to make requests to the remote end. + * + * @since 1.1 + */ +class RSocketClientAdapter implements RSocketClient { + + private final RSocket rsocket; + + public RSocketClientAdapter(RSocket rsocket) { + this.rsocket = rsocket; + } + + public RSocket rsocket() { + return rsocket; + } + + @Override + public boolean connect() { + throw new UnsupportedOperationException("Connect does not apply to a server side RSocket"); + } + + @Override + public Mono source() { + return Mono.just(rsocket); + } + + @Override + public Mono onClose() { + return rsocket.onClose(); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return payloadMono.flatMap(rsocket::fireAndForget); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return payloadMono.flatMap(rsocket::requestResponse); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return payloadMono.flatMapMany(rsocket::requestStream); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return rsocket.requestChannel(payloads); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return payloadMono.flatMap(rsocket::metadataPush); + } + + @Override + public void dispose() { + rsocket.dispose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java new file mode 100644 index 000000000..de494c4e3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketConnector.java @@ -0,0 +1,746 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.assertMtu; +import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.resume.ClientRSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumableFramesStore; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +/** + * The main class to use to establish a connection to an RSocket server. + * + *

For using TCP using default settings: + * + *

{@code
+ * import io.rsocket.transport.netty.client.TcpClientTransport;
+ *
+ * Mono source =
+ *         RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ * + *

To customize connection settings before connecting: + * + *

{@code
+ * Mono source =
+ *         RSocketConnector.create()
+ *                 .metadataMimeType("message/x.rsocket.composite-metadata.v0")
+ *                 .dataMimeType("application/cbor")
+ *                 .connect(TcpClientTransport.create("localhost", 7000));
+ * RSocketClient client = RSocketClient.from(source);
+ * }
+ */ +public class RSocketConnector { + private static final String CLIENT_TAG = "client"; + + private static final BiConsumer INVALIDATE_FUNCTION = + (r, i) -> r.onClose().subscribe(null, __ -> i.invalidate(), i::invalidate); + + private Mono setupPayloadMono = Mono.empty(); + private String metadataMimeType = "application/binary"; + private String dataMimeType = "application/binary"; + private Duration keepAliveInterval = Duration.ofSeconds(20); + private Duration keepAliveMaxLifeTime = Duration.ofSeconds(90); + + @Nullable private SocketAcceptor acceptor; + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Retry retrySpec; + private Resume resume; + + @Nullable private Consumer leaseConfigurer; + + private int mtu = 0; + private int maxInboundPayloadSize = Integer.MAX_VALUE; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + + private RSocketConnector() {} + + /** + * Static factory method to create an {@code RSocketConnector} instance and customize default + * settings before connecting. To connect only, use {@link #connectWith(ClientTransport)}. + */ + public static RSocketConnector create() { + return new RSocketConnector(); + } + + /** + * Static factory method to connect with default settings, effectively a shortcut for: + * + *
+   * RSocketConnector.create().connect(transport);
+   * 
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public static Mono connectWith(ClientTransport transport) { + return RSocketConnector.create().connect(() -> transport); + } + + /** + * Provide a {@code Mono} from which to obtain the {@code Payload} for the initial SETUP frame. + * Data and metadata should be formatted according to the MIME types specified via {@link + * #dataMimeType(String)} and {@link #metadataMimeType(String)}. + * + * @param setupPayloadMono the payload with data and/or metadata for the {@code SETUP} frame. + * @return the same instance for method chaining + * @since 1.0.2 + * @see SETUP + * Frame + */ + public RSocketConnector setupPayload(Mono setupPayloadMono) { + this.setupPayloadMono = setupPayloadMono; + return this; + } + + /** + * Variant of {@link #setupPayload(Mono)} that accepts a {@code Payload} instance. + * + *

Note: if the given payload is {@link io.rsocket.util.ByteBufPayload}, it is copied to a + * {@link DefaultPayload} and released immediately. This ensures it can re-used to obtain a + * connection more than once. + * + * @param payload the payload with data and/or metadata for the {@code SETUP} frame. + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector setupPayload(Payload payload) { + if (payload instanceof DefaultPayload) { + this.setupPayloadMono = Mono.just(payload); + } else { + this.setupPayloadMono = Mono.just(DefaultPayload.create(Objects.requireNonNull(payload))); + payload.release(); + } + return this; + } + + /** + * Set the MIME type to use for formatting payload data on the established connection. This is set + * in the initial {@code SETUP} frame sent to the server. + * + *

By default this is set to {@code "application/binary"}. + * + * @param dataMimeType the MIME type to be used for payload data + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector dataMimeType(String dataMimeType) { + this.dataMimeType = Objects.requireNonNull(dataMimeType); + return this; + } + + /** + * Set the MIME type to use for formatting payload metadata on the established connection. This is + * set in the initial {@code SETUP} frame sent to the server. + * + *

For metadata encoding, consider using one of the following encoders: + * + *

    + *
  • {@link io.rsocket.metadata.CompositeMetadataCodec Composite Metadata} + *
  • {@link io.rsocket.metadata.TaggingMetadataCodec Routing} + *
  • {@link io.rsocket.metadata.AuthMetadataCodec Authentication} + *
+ * + *

For more on the above metadata formats, see the corresponding protocol extensions + * + *

By default this is set to {@code "application/binary"}. + * + * @param metadataMimeType the MIME type to be used for payload metadata + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector metadataMimeType(String metadataMimeType) { + this.metadataMimeType = Objects.requireNonNull(metadataMimeType); + return this; + } + + /** + * Set the "Time Between {@code KEEPALIVE} Frames" which is how frequently {@code KEEPALIVE} + * frames should be emitted, and the "Max Lifetime" which is how long to allow between {@code + * KEEPALIVE} frames from the remote end before concluding that connectivity is lost. Both + * settings are specified in the initial {@code SETUP} frame sent to the server. The spec mentions + * the following: + * + *

    + *
  • For server-to-server connections, a reasonable time interval between client {@code + * KEEPALIVE} frames is 500ms. + *
  • For mobile-to-server connections, the time interval between client {@code KEEPALIVE} + * frames is often {@code >} 30,000ms. + *
+ * + *

By default these are set to 20 seconds and 90 seconds respectively. + * + * @param interval how frequently to emit KEEPALIVE frames + * @param maxLifeTime how long to allow between {@code KEEPALIVE} frames from the remote end + * before assuming that connectivity is lost; the value should be generous and allow for + * multiple missed {@code KEEPALIVE} frames. + * @return the same instance for method chaining + * @see SETUP + * Frame + */ + public RSocketConnector keepAlive(Duration interval, Duration maxLifeTime) { + if (!interval.negated().isNegative()) { + throw new IllegalArgumentException("`interval` for keepAlive must be > 0"); + } + if (!maxLifeTime.negated().isNegative()) { + throw new IllegalArgumentException("`maxLifeTime` for keepAlive must be > 0"); + } + this.keepAliveInterval = interval; + this.keepAliveMaxLifeTime = maxLifeTime; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketConnector interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Configure a client-side {@link SocketAcceptor} for responding to requests from the server. + * + *

A full-form example with access to the {@code SETUP} frame and the "sending" RSocket (the + * same as the one returned from {@link #connect(ClientTransport)}): + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example with just the handling RSocket: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.with(new RSocket() {...})))
+   *             .connect(transport);
+   * }
+ * + *

A shortcut example handling only request-response: + * + *

{@code
+   * Mono rsocketMono =
+   *     RSocketConnector.create()
+   *             .acceptor(SocketAcceptor.forRequestResponse(payload -> ...))
+   *             .connect(transport);
+   * }
+ * + *

By default, {@code new RSocket(){}} is used which rejects all requests from the server with + * {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to use for responding to server requests + * @return the same instance for method chaining + */ + public RSocketConnector acceptor(SocketAcceptor acceptor) { + this.acceptor = acceptor; + return this; + } + + /** + * When this is enabled, the connect methods of this class return a special {@code Mono} + * that maintains a single, shared {@code RSocket} for all subscribers: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   * }
+ * + *

The {@code RSocket} remains cached until the connection is lost and after that, new attempts + * to subscribe or re-subscribe trigger a reconnect and result in a new shared {@code RSocket}: + * + *

{@code
+   * Mono rsocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  RSocket r1 = rsocketMono.block();
+   *  RSocket r2 = rsocketMono.block();
+   *
+   *  r1.dispose();
+   *
+   *  RSocket r3 = rsocketMono.block();
+   *  RSocket r4 = rsocketMono.block();
+   *
+   *  assert r1 == r2;
+   *  assert r3 == r4;
+   *  assert r1 != r3;
+   *
+   * }
+ * + *

Downstream subscribers for individual requests still need their own retry logic to determine + * if or when failed requests should be retried which in turn triggers the shared reconnect: + * + *

{@code
+   * Mono rocketMono =
+   *   RSocketConnector.create()
+   *           .reconnect(Retry.fixedDelay(3, Duration.ofSeconds(1)))
+   *           .connect(transport);
+   *
+   *  rsocketMono.flatMap(rsocket -> rsocket.requestResponse(...))
+   *           .retryWhen(Retry.fixedDelay(1, Duration.ofSeconds(5)))
+   *           .subscribe()
+   * }
+ * + *

Note: this feature is mutually exclusive with {@link #resume(Resume)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled in which case a new connection is obtained per subscriber. + * + * @param retry a retry spec that declares the rules for reconnecting + * @return the same instance for method chaining + */ + public RSocketConnector reconnect(Retry retry) { + this.retrySpec = Objects.requireNonNull(retry); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are resumed automatically. For this + * to work the server must also support and have the Resume capability enabled. + * + *

See {@link Resume} for settings to customize the Resume capability. + * + *

Note: this feature is mutually exclusive with {@link #reconnect(Retry)}. If + * both are enabled, "resume" takes precedence. Consider using "reconnect" when the server does + * not have "resume" enabled or supported, or when you don't need to incur the overhead of saving + * in-flight frames to be potentially replayed after a reconnect. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketConnector resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

Example usage: + * + *

{@code
+   * Mono rocketMono =
+   *         RSocketConnector.create()
+   *                         .lease()
+   *                         .connect(transport);
+   * }
+ * + *

By default this is not enabled. + * + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease() { + return lease((config -> {})); + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. + * + *

Example usage: + * + *

{@code
+   * Mono rocketMono =
+   *         RSocketConnector.create()
+   *                         .lease(spec -> spec.maxPendingRequests(128))
+   *                         .connect(transport);
+   * }
+ * + *

By default this is not enabled. + * + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketConnector lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; + return this; + } + + /** + * When this is set, frames reassembler control maximum payload size which can be reassembled. + * + *

By default this is not set in which case maximum reassembled payloads size is not + * controlled. + * + * @param maxInboundPayloadSize the threshold size for reassembly, must no be less than 64 bytes. + * Please note, {@code maxInboundPayloadSize} must always be greater or equal to {@link + * io.rsocket.transport.Transport#maxFrameLength()}, otherwise inbound frame can exceed the + * {@code maxInboundPayloadSize} + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketConnector maxInboundPayloadSize(int maxInboundPayloadSize) { + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * broken down into fragments to fit that size. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketConnector fragment(int mtu) { + this.mtu = assertMtu(mtu); + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketConnector payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * Connect with the given transport and obtain a live {@link RSocket} to use for making requests. + * Each subscriber to the returned {@code Mono} receives a new connection, if neither {@link + * #reconnect(Retry) reconnect} nor {@link #resume(Resume)} are enabled. + * + *

The following transports are available through additional RSocket Java modules: + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpClientTransport TcpClientTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketClientTransport + * WebsocketClientTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalClientTransport LocalClientTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(ClientTransport transport) { + return connect(() -> transport); + } + + /** + * Variant of {@link #connect(ClientTransport)} with a {@link Supplier} for the {@code + * ClientTransport}. + * + *

// TODO: when to use? + * + * @param transportSupplier supplier for the transport to connect with + * @return a {@code Mono} with the connected RSocket + */ + public Mono connect(Supplier transportSupplier) { + return Mono.fromSupplier(transportSupplier) + .flatMap( + ct -> { + int maxFrameLength = ct.maxFrameLength(); + + Mono connectionMono = + Mono.fromCallable( + () -> { + assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); + return ct; + }) + .flatMap(transport -> transport.connect()) + .map( + sourceConnection -> + interceptors.initConnection( + DuplexConnectionInterceptor.Type.SOURCE, sourceConnection)) + .map(source -> LoggingDuplexConnection.wrapIfEnabled(source)); + + return connectionMono + .flatMap( + connection -> + setupPayloadMono + .defaultIfEmpty(EmptyPayload.INSTANCE) + .map(setupPayload -> Tuples.of(connection, setupPayload)) + .doOnError(ex -> connection.dispose()) + .doOnCancel(connection::dispose)) + .flatMap( + tuple2 -> { + DuplexConnection sourceConnection = tuple2.getT1(); + Payload setupPayload = tuple2.getT2(); + boolean leaseEnabled = leaseConfigurer != null; + boolean resumeEnabled = resume != null; + // TODO: add LeaseClientSetup + ClientSetup clientSetup = new DefaultClientSetup(); + ByteBuf resumeToken; + + if (resumeEnabled) { + resumeToken = resume.getTokenSupplier().get(); + } else { + resumeToken = Unpooled.EMPTY_BUFFER; + } + + ByteBuf setupFrame = + SetupFrameCodec.encode( + sourceConnection.alloc(), + leaseEnabled, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + resumeToken, + metadataMimeType, + dataMimeType, + setupPayload); + + sourceConnection.sendFrame(0, setupFrame.retainedSlice()); + + return clientSetup + .init(sourceConnection) + .flatMap( + tuple -> { + // should be used if lease setup sequence; + // See: + // https://github.com/rsocket/rsocket/blob/master/Protocol.md#sequences-with-lease + final ByteBuf serverResponse = tuple.getT1(); + final DuplexConnection clientServerConnection = tuple.getT2(); + final KeepAliveHandler keepAliveHandler; + final DuplexConnection wrappedConnection; + final InitializingInterceptorRegistry interceptors = + this.interceptors; + + if (resumeEnabled) { + final ResumableFramesStore resumableFramesStore = + resume.getStoreFactory(CLIENT_TAG).apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + CLIENT_TAG, + resumeToken, + clientServerConnection, + resumableFramesStore); + final ResumableClientSetup resumableClientSetup = + new ResumableClientSetup(); + final ClientRSocketSession session = + new ClientRSocketSession( + resumeToken, + resumableDuplexConnection, + connectionMono, + resumableClientSetup::init, + resumableFramesStore, + resume.getSessionDuration(), + resume.getRetry(), + resume.isCleanupStoreOnKeepAlive()); + keepAliveHandler = + new KeepAliveHandler.ResumableKeepAliveHandler( + resumableDuplexConnection, session, session); + wrappedConnection = resumableDuplexConnection; + } else { + keepAliveHandler = + new KeepAliveHandler.DefaultKeepAliveHandler(); + wrappedConnection = clientServerConnection; + } + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer( + wrappedConnection, interceptors, true); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker( + CLIENT_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } + + final Sinks.Empty requesterOnAllClosedSink = + Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = + Sinks.unsafe().empty(); + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + StreamIdSupplier.clientSupplier(), + mtu, + maxFrameLength, + maxInboundPayloadSize, + (int) keepAliveInterval.toMillis(), + (int) keepAliveMaxLifeTime.toMillis(), + keepAliveHandler, + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), + requesterOnAllClosedSink.asMono())); + + RSocket wrappedRSocketRequester = + interceptors.initRequester(rSocketRequester); + + SocketAcceptor acceptor = + this.acceptor != null + ? this.acceptor + : SocketAcceptor.with(new RSocket() {}); + + ConnectionSetupPayload setup = + new DefaultConnectionSetupPayload(setupFrame); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setup, wrappedRSocketRequester) + .map( + rSocketHandler -> { + RSocket wrappedRSocketHandler = + interceptors.initResponder(rSocketHandler); + + ResponderLeaseTracker responderLeaseTracker = + leaseEnabled + ? new ResponderLeaseTracker( + CLIENT_TAG, + wrappedConnection, + leases.sender) + : null; + + RSocket rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + wrappedRSocketHandler, + payloadDecoder, + responderLeaseTracker, + mtu, + maxFrameLength, + maxInboundPayloadSize, + leaseEnabled + && leases.sender + instanceof TrackingLeaseSender + ? rSocket -> + interceptors + .initResponderRequestInterceptor( + rSocket, + (RequestInterceptor) + leases.sender) + : interceptors + ::initResponderRequestInterceptor, + responderOnAllClosedSink); + + return wrappedRSocketRequester; + }) + .doFinally(signalType -> setup.release()); + }); + }); + }) + .as( + source -> { + if (retrySpec != null) { + return new ReconnectMono<>( + source.retryWhen(retrySpec), Disposable::dispose, INVALIDATE_FUNCTION); + } else { + return source; + } + }); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java new file mode 100644 index 000000000..b8a9c00ff --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketRequester.java @@ -0,0 +1,445 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.keepalive.KeepAliveSupport.ClientKeepAliveSupport; + +import io.netty.buffer.ByteBuf; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.keepalive.KeepAliveFramesAcceptor; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.plugins.RequestInterceptor; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** + * Requester Side of a RSocket socket. Sends {@link ByteBuf}s to a {@link RSocketResponder} of peer + */ +class RSocketRequester extends RequesterResponderSupport implements RSocket { + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketRequester.class); + + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + static { + CLOSED_CHANNEL_EXCEPTION.setStackTrace(new StackTraceElement[0]); + } + + private volatile Throwable terminationError; + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketRequester.class, Throwable.class, "terminationError"); + + @Nullable private final RequesterLeaseTracker requesterLeaseTracker; + + private final Sinks.Empty onThisSideClosedSink; + private final Mono onAllClosed; + private final KeepAliveFramesAcceptor keepAliveFramesAcceptor; + + RSocketRequester( + DuplexConnection connection, + PayloadDecoder payloadDecoder, + StreamIdSupplier streamIdSupplier, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + int keepAliveTickPeriod, + int keepAliveAckTimeout, + @Nullable KeepAliveHandler keepAliveHandler, + Function requestInterceptorFunction, + @Nullable RequesterLeaseTracker requesterLeaseTracker, + Sinks.Empty onThisSideClosedSink, + Mono onAllClosed) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + streamIdSupplier, + requestInterceptorFunction); + + this.requesterLeaseTracker = requesterLeaseTracker; + this.onThisSideClosedSink = onThisSideClosedSink; + this.onAllClosed = onAllClosed; + + // DO NOT Change the order here. The Send processor must be subscribed to before receiving + connection.onClose().subscribe(null, this::tryShutdown, this::tryShutdown); + + connection.receive().subscribe(this::handleIncomingFrames, e -> {}); + + if (keepAliveTickPeriod != 0 && keepAliveHandler != null) { + KeepAliveSupport keepAliveSupport = + new ClientKeepAliveSupport(this.getAllocator(), keepAliveTickPeriod, keepAliveAckTimeout); + this.keepAliveFramesAcceptor = + keepAliveHandler.start( + keepAliveSupport, + (keepAliveFrame) -> connection.sendFrame(0, keepAliveFrame), + this::tryTerminateOnKeepAlive); + } else { + keepAliveFramesAcceptor = null; + } + } + + @Override + public Mono fireAndForget(Payload payload) { + if (this.requesterLeaseTracker == null) { + return new FireAndForgetRequesterMono(payload, this); + } else { + return new SlowFireAndForgetRequesterMono(payload, this); + } + } + + @Override + public Mono requestResponse(Payload payload) { + return new RequestResponseRequesterMono(payload, this); + } + + @Override + public Flux requestStream(Payload payload) { + return new RequestStreamRequesterFlux(payload, this); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new RequestChannelRequesterFlux(payloads, this); + } + + @Override + public Mono metadataPush(Payload payload) { + Throwable terminationError = this.terminationError; + if (terminationError != null) { + payload.release(); + return Mono.error(terminationError); + } + + return new MetadataPushRequesterMono(payload, this); + } + + @Override + public RequesterLeaseTracker getRequesterLeaseTracker() { + return this.requesterLeaseTracker; + } + + @Override + public int getNextStreamId() { + int nextStreamId = super.getNextStreamId(); + + Throwable terminationError = this.terminationError; + if (terminationError != null) { + throw reactor.core.Exceptions.propagate(terminationError); + } + + return nextStreamId; + } + + @Override + public int addAndGetNextStreamId(FrameHandler frameHandler) { + int nextStreamId = super.addAndGetNextStreamId(frameHandler); + + Throwable terminationError = this.terminationError; + if (terminationError != null) { + super.remove(nextStreamId, frameHandler); + throw reactor.core.Exceptions.propagate(terminationError); + } + + return nextStreamId; + } + + @Override + public double availability() { + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + return Math.min(getDuplexConnection().availability(), requesterLeaseTracker.availability()); + } else { + return getDuplexConnection().availability(); + } + } + + @Override + public void dispose() { + if (terminationError != null) { + return; + } + + getDuplexConnection().sendErrorAndClose(new ConnectionErrorException("Disposed")); + } + + @Override + public boolean isDisposed() { + return terminationError != null; + } + + @Override + public Mono onClose() { + return onAllClosed; + } + + private void handleIncomingFrames(ByteBuf frame) { + try { + int streamId = FrameHeaderCodec.streamId(frame); + FrameType type = FrameHeaderCodec.frameType(frame); + if (streamId == 0) { + handleStreamZero(type, frame); + } else { + handleFrame(streamId, type, frame); + } + } catch (Throwable t) { + LOGGER.error("Unexpected error during frame handling", t); + final ConnectionErrorException error = + new ConnectionErrorException("Unexpected error during frame handling", t); + getDuplexConnection().sendErrorAndClose(error); + } + } + + private void handleStreamZero(FrameType type, ByteBuf frame) { + switch (type) { + case ERROR: + tryTerminateOnZeroError(frame); + break; + case LEASE: + requesterLeaseTracker.handleLeaseFrame(frame); + break; + case KEEPALIVE: + if (keepAliveFramesAcceptor != null) { + keepAliveFramesAcceptor.receive(frame); + } + break; + default: + // Ignore unknown frames. Throwing an error will close the socket. + if (LOGGER.isInfoEnabled()) { + LOGGER.info("Requester received unsupported frame on stream 0: " + frame.toString()); + } + } + } + + private void handleFrame(int streamId, FrameType type, ByteBuf frame) { + FrameHandler receiver = this.get(streamId); + if (receiver == null) { + handleMissingResponseProcessor(streamId, type, frame); + return; + } + + switch (type) { + case NEXT_COMPLETE: + receiver.handleNext(frame, false, true); + break; + case NEXT: + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); + break; + case COMPLETE: + receiver.handleComplete(); + break; + case ERROR: + receiver.handleError(Exceptions.from(streamId, frame)); + break; + case CANCEL: + receiver.handleCancel(); + break; + case REQUEST_N: + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + break; + default: + throw new IllegalStateException( + "Requester received unsupported frame on stream " + streamId + ": " + frame.toString()); + } + } + + @SuppressWarnings("ConstantConditions") + private void handleMissingResponseProcessor(int streamId, FrameType type, ByteBuf frame) { + if (!super.streamIdSupplier.isBeforeOrCurrent(streamId)) { + if (type == FrameType.ERROR) { + // message for stream that has never existed, we have a problem with + // the overall connection and must tear down + String errorMessage = ErrorFrameCodec.dataUtf8(frame); + + throw new IllegalStateException( + "Client received error for non-existent stream: " + + streamId + + " Message: " + + errorMessage); + } else { + throw new IllegalStateException( + "Client received message for non-existent stream: " + + streamId + + ", frame type: " + + type); + } + } + // receiving a frame after a given stream has been cancelled/completed, + // so ignore (cancellation is async so there is a race condition) + } + + private void tryTerminateOnKeepAlive(KeepAliveSupport.KeepAlive keepAlive) { + tryTerminate( + () -> + new ConnectionErrorException( + String.format("No keep-alive acks for %d ms", keepAlive.getTimeout().toMillis()))); + getDuplexConnection().dispose(); + } + + private void tryShutdown(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.info( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } + + private void tryTerminateOnZeroError(ByteBuf errorFrame) { + tryTerminate(() -> Exceptions.from(0, errorFrame)); + } + + private void tryTerminate(Supplier errorSupplier) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + terminate(e); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } + + private void tryShutdown() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("trying to close requester " + getDuplexConnection()); + } + if (terminationError == null) { + if (TERMINATION_ERROR.compareAndSet(this, null, CLOSED_CHANNEL_EXCEPTION)) { + terminate(CLOSED_CHANNEL_EXCEPTION); + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } else { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "trying to close requester failed because of " + + terminationError + + " " + + getDuplexConnection()); + } + } + } + + private void terminate(Throwable e) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing requester " + getDuplexConnection() + " due to " + e); + } + if (keepAliveFramesAcceptor != null) { + keepAliveFramesAcceptor.dispose(); + } + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + if (requesterLeaseTracker != null) { + requesterLeaseTracker.dispose(e); + } + + final Collection activeStreamsCopy; + synchronized (this) { + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } + + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + try { + handler.handleError(e); + } catch (Throwable ignored) { + } + } + } + + if (e == CLOSED_CHANNEL_EXCEPTION) { + onThisSideClosedSink.tryEmitEmpty(); + } else { + onThisSideClosedSink.tryEmitError(e); + } + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("requester closed " + getDuplexConnection()); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java new file mode 100644 index 000000000..50c5ba54c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketResponder.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** Responder side of RSocket. Receives {@link ByteBuf}s from a peer's {@link RSocketRequester} */ +class RSocketResponder extends RequesterResponderSupport implements RSocket { + + private static final Logger LOGGER = LoggerFactory.getLogger(RSocketResponder.class); + + private static final Exception CLOSED_CHANNEL_EXCEPTION = new ClosedChannelException(); + + private final RSocket requestHandler; + private final Sinks.Empty onThisSideClosedSink; + + @Nullable private final ResponderLeaseTracker leaseHandler; + + private volatile Throwable terminationError; + private static final AtomicReferenceFieldUpdater TERMINATION_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RSocketResponder.class, Throwable.class, "terminationError"); + + RSocketResponder( + DuplexConnection connection, + RSocket requestHandler, + PayloadDecoder payloadDecoder, + @Nullable ResponderLeaseTracker leaseHandler, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + Function requestInterceptorFunction, + Sinks.Empty onThisSideClosedSink) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + payloadDecoder, + connection, + null, + requestInterceptorFunction); + + this.requestHandler = requestHandler; + + this.leaseHandler = leaseHandler; + this.onThisSideClosedSink = onThisSideClosedSink; + + connection + .onClose() + .subscribe(null, this::tryTerminateOnConnectionError, this::tryTerminateOnConnectionClose); + + connection.receive().subscribe(this::handleFrame, e -> {}); + } + + private void tryTerminateOnConnectionError(Throwable e) { + if (LOGGER.isDebugEnabled()) { + + LOGGER.debug("Try terminate connection on responder side"); + } + tryTerminate(() -> e); + } + + private void tryTerminateOnConnectionClose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.info("Try terminate connection on responder side"); + } + tryTerminate(() -> CLOSED_CHANNEL_EXCEPTION); + } + + private void tryTerminate(Supplier errorSupplier) { + if (terminationError == null) { + Throwable e = errorSupplier.get(); + if (TERMINATION_ERROR.compareAndSet(this, null, e)) { + doOnDispose(); + } + } + } + + @Override + public Mono fireAndForget(Payload payload) { + try { + return requestHandler.fireAndForget(payload); + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public Mono requestResponse(Payload payload) { + try { + return requestHandler.requestResponse(payload); + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public Flux requestStream(Payload payload) { + try { + return requestHandler.requestStream(payload); + } catch (Throwable t) { + return Flux.error(t); + } + } + + @Override + public Flux requestChannel(Publisher payloads) { + try { + return requestHandler.requestChannel(payloads); + } catch (Throwable t) { + return Flux.error(t); + } + } + + @Override + public Mono metadataPush(Payload payload) { + try { + return requestHandler.metadataPush(payload); + } catch (Throwable t) { + return Mono.error(t); + } + } + + @Override + public void dispose() { + tryTerminate(() -> new CancellationException("Disposed")); + } + + @Override + public boolean isDisposed() { + return getDuplexConnection().isDisposed(); + } + + @Override + public Mono onClose() { + return getDuplexConnection().onClose(); + } + + final void doOnDispose() { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("closing responder " + getDuplexConnection()); + } + cleanUpSendingSubscriptions(); + + getDuplexConnection().dispose(); + final RequestInterceptor requestInterceptor = getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.dispose(); + } + + final ResponderLeaseTracker handler = leaseHandler; + if (handler != null) { + handler.dispose(); + } + + requestHandler.dispose(); + onThisSideClosedSink.tryEmitEmpty(); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug("responder closed " + getDuplexConnection()); + } + } + + private void cleanUpSendingSubscriptions() { + final Collection activeStreamsCopy; + synchronized (this) { + final IntObjectMap activeStreams = this.activeStreams; + activeStreamsCopy = new ArrayList<>(activeStreams.values()); + } + + for (FrameHandler handler : activeStreamsCopy) { + if (handler != null) { + handler.handleCancel(); + } + } + } + + final void handleFrame(ByteBuf frame) { + try { + int streamId = FrameHeaderCodec.streamId(frame); + FrameHandler receiver; + FrameType frameType = FrameHeaderCodec.frameType(frame); + switch (frameType) { + case REQUEST_FNF: + handleFireAndForget(streamId, frame); + break; + case REQUEST_RESPONSE: + handleRequestResponse(streamId, frame); + break; + case REQUEST_STREAM: + long streamInitialRequestN = RequestStreamFrameCodec.initialRequestN(frame); + handleStream(streamId, frame, streamInitialRequestN); + break; + case REQUEST_CHANNEL: + long channelInitialRequestN = RequestChannelFrameCodec.initialRequestN(frame); + handleChannel( + streamId, frame, channelInitialRequestN, FrameHeaderCodec.hasComplete(frame)); + break; + case METADATA_PUSH: + handleMetadataPush(metadataPush(super.getPayloadDecoder().apply(frame))); + break; + case CANCEL: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleCancel(); + } + break; + case REQUEST_N: + receiver = super.get(streamId); + if (receiver != null) { + long n = RequestNFrameCodec.requestN(frame); + receiver.handleRequestN(n); + } + break; + case PAYLOAD: + // TODO: Hook in receiving socket. + break; + case NEXT: + receiver = super.get(streamId); + if (receiver != null) { + boolean hasFollows = FrameHeaderCodec.hasFollows(frame); + receiver.handleNext(frame, hasFollows, false); + } + break; + case COMPLETE: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleComplete(); + } + break; + case ERROR: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleError(io.rsocket.exceptions.Exceptions.from(streamId, frame)); + } + break; + case NEXT_COMPLETE: + receiver = super.get(streamId); + if (receiver != null) { + receiver.handleNext(frame, false, true); + } + break; + case SETUP: + getDuplexConnection() + .sendFrame( + streamId, + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException("Setup frame received post setup."))); + break; + case LEASE: + default: + getDuplexConnection() + .sendFrame( + streamId, + ErrorFrameCodec.encode( + super.getAllocator(), + streamId, + new IllegalStateException( + "ServerRSocket: Unexpected frame type: " + frameType))); + break; + } + } catch (Throwable t) { + LOGGER.error("Unexpected error during frame handling", t); + getDuplexConnection() + .sendFrame( + 0, + ErrorFrameCodec.encode( + super.getAllocator(), + 0, + new ConnectionErrorException("Unexpected error during frame handling", t))); + this.tryTerminateOnConnectionError(t); + } + } + + final void handleFireAndForget(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + if (FrameHeaderCodec.hasFollows(frame)) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } + + FireAndForgetResponderSubscriber subscriber = + new FireAndForgetResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(new FireAndForgetResponderSubscriber(streamId, this)); + } else { + fireAndForget(super.getPayloadDecoder().apply(frame)) + .subscribe(FireAndForgetResponderSubscriber.INSTANCE); + } + } + } else { + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_FNF, RequestFireAndForgetFrameCodec.metadata(frame)); + } + } + } + + final void handleRequestResponse(int streamId, ByteBuf frame) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, frame, this, this); + + this.add(streamId, subscriber); + } else { + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, this); + + if (this.add(streamId, subscriber)) { + this.requestResponse(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_RESPONSE, RequestResponseFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); + } + } + + final void handleStream(int streamId, ByteBuf frame, long initialRequestN) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, frame, this, this); + + this.add(streamId, subscriber); + } else { + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, this); + + if (this.add(streamId, subscriber)) { + this.requestStream(super.getPayloadDecoder().apply(frame)).subscribe(subscriber); + } + } + } else { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onReject( + leaseError, FrameType.REQUEST_STREAM, RequestStreamFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); + } + } + + final void handleChannel(int streamId, ByteBuf frame, long initialRequestN, boolean complete) { + ResponderLeaseTracker leaseHandler = this.leaseHandler; + Throwable leaseError; + if (leaseHandler == null || (leaseError = leaseHandler.use()) == null) { + final RequestInterceptor requestInterceptor = this.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart( + streamId, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + + if (FrameHeaderCodec.hasFollows(frame)) { + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, frame, this, this); + + this.add(streamId, subscriber); + } else { + final Payload firstPayload = super.getPayloadDecoder().apply(frame); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber(streamId, initialRequestN, firstPayload, this); + + if (this.add(streamId, subscriber)) { + this.requestChannel(subscriber).subscribe(subscriber); + if (complete) { + subscriber.handleComplete(); + } + } + } + } else { + final RequestInterceptor requestTracker = this.getRequestInterceptor(); + if (requestTracker != null) { + requestTracker.onReject( + leaseError, FrameType.REQUEST_CHANNEL, RequestChannelFrameCodec.metadata(frame)); + } + sendLeaseRejection(streamId, leaseError); + } + } + + private void sendLeaseRejection(int streamId, Throwable leaseError) { + getDuplexConnection() + .sendFrame(streamId, ErrorFrameCodec.encode(getAllocator(), streamId, leaseError)); + } + + private void handleMetadataPush(Mono result) { + result.subscribe(MetadataPushResponderSubscriber.INSTANCE); + } + + @Override + public boolean add(int streamId, FrameHandler frameHandler) { + if (!super.add(streamId, frameHandler)) { + frameHandler.handleCancel(); + return false; + } + + return true; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java new file mode 100644 index 000000000..e969c39d2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RSocketServer.java @@ -0,0 +1,523 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.assertMtu; +import static io.rsocket.core.PayloadValidationUtils.assertValidateSetup; +import static io.rsocket.core.ReassemblyUtils.assertInboundPayloadSize; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; +import io.rsocket.SocketAcceptor; +import io.rsocket.exceptions.InvalidSetupException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.lease.TrackingLeaseSender; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.plugins.InterceptorRegistry; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.resume.SessionManager; +import io.rsocket.transport.ServerTransport; +import java.time.Duration; +import java.util.Objects; +import java.util.function.Consumer; +import java.util.function.Supplier; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * The main class for starting an RSocket server. + * + *

For example: + * + *

{@code
+ * CloseableChannel closeable =
+ *         RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+ *                 .bind(TcpServerTransport.create("localhost", 7000))
+ *                 .block();
+ * }
+ */ +public final class RSocketServer { + private static final String SERVER_TAG = "server"; + + private SocketAcceptor acceptor = SocketAcceptor.with(new RSocket() {}); + private InitializingInterceptorRegistry interceptors = new InitializingInterceptorRegistry(); + + private Resume resume; + private Consumer leaseConfigurer = null; + + private int mtu = 0; + private int maxInboundPayloadSize = Integer.MAX_VALUE; + private PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + private Duration timeout = Duration.ofMinutes(1); + + private RSocketServer() {} + + /** Static factory method to create an {@code RSocketServer}. */ + public static RSocketServer create() { + return new RSocketServer(); + } + + /** + * Static factory method to create an {@code RSocketServer} instance with the given {@code + * SocketAcceptor}. Effectively a shortcut for: + * + *
+   * RSocketServer.create().acceptor(...);
+   * 
+ * + * @param acceptor the acceptor to handle connections with + * @return the same instance for method chaining + * @see #acceptor(SocketAcceptor) + */ + public static RSocketServer create(SocketAcceptor acceptor) { + return RSocketServer.create().acceptor(acceptor); + } + + /** + * Set the acceptor to handle incoming connections and handle requests. + * + *

An example with access to the {@code SETUP} frame and sending RSocket for performing + * requests back to the client if needed: + * + *

{@code
+   * RSocketServer.create((setup, sendingRSocket) -> Mono.just(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to provide the handling RSocket only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

A shortcut to handle request-response interactions only: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.forRequestResponse(payload -> ...))
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default, {@code new RSocket(){}} is used for handling which rejects requests from the + * client with {@link UnsupportedOperationException}. + * + * @param acceptor the acceptor to handle incoming connections and requests with + * @return the same instance for method chaining + */ + public RSocketServer acceptor(SocketAcceptor acceptor) { + Objects.requireNonNull(acceptor); + this.acceptor = acceptor; + return this; + } + + /** + * Configure interception at one of the following levels: + * + *

    + *
  • Transport level + *
  • At the level of accepting new connections + *
  • Performing requests + *
  • Responding to requests + *
+ * + * @param configurer a configurer to customize interception with. + * @return the same instance for method chaining + * @see io.rsocket.plugins.LimitRateInterceptor + */ + public RSocketServer interceptors(Consumer configurer) { + configurer.accept(this.interceptors); + return this; + } + + /** + * Enables the Resume capability of the RSocket protocol where if the client gets disconnected, + * the connection is re-acquired and any interrupted streams are transparently resumed. For this + * to work clients must also support and request to enable this when connecting. + * + *

Use the {@link Resume} argument to customize the Resume session duration, storage, retry + * logic, and others. + * + *

By default this is not enabled. + * + * @param resume configuration for the Resume capability + * @return the same instance for method chaining + * @see Resuming + * Operation + */ + public RSocketServer resume(Resume resume) { + this.resume = resume; + return this; + } + + /** + * Enables the Lease feature of the RSocket protocol where the number of requests that can be + * performed from either side are rationed via {@code LEASE} frames from the responder side. For + * this to work clients must also support and request to enable this when connecting. + * + *

Example usage: + * + *

{@code
+   * RSocketServer.create(SocketAcceptor.with(new RSocket() {...}))
+   *         .lease(spec ->
+   *            spec.sender(() -> Flux.interval(ofSeconds(1))
+   *                                  .map(__ -> Lease.create(ofSeconds(1), 1)))
+   *         )
+   *         .bind(TcpServerTransport.create("localhost", 7000))
+   *         .subscribe();
+   * }
+ * + *

By default this is not enabled. + * + * @param leaseConfigurer consumer which accepts {@link LeaseSpec} and use it for configuring + * @return the same instance for method chaining + * @see Lease + * Semantics + */ + public RSocketServer lease(Consumer leaseConfigurer) { + this.leaseConfigurer = leaseConfigurer; + return this; + } + + /** + * When this is set, frames reassembler control maximum payload size which can be reassembled. + * + *

By default this is not set in which case maximum reassembled payloads size is not + * controlled. + * + * @param maxInboundPayloadSize the threshold size for reassembly, must no be less than 64 bytes. + * Please note, {@code maxInboundPayloadSize} must always be greater or equal to {@link + * io.rsocket.transport.Transport#maxFrameLength()}, otherwise inbound frame can exceed the + * {@code maxInboundPayloadSize} + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketServer maxInboundPayloadSize(int maxInboundPayloadSize) { + this.maxInboundPayloadSize = assertInboundPayloadSize(maxInboundPayloadSize); + return this; + } + + /** + * Specify the max time to wait for the first frame (e.g. {@code SETUP}) on an accepted + * connection. + * + *

By default this is set to 1 minute. + * + * @param timeout duration + * @return the same instance for method chaining + */ + public RSocketServer maxTimeToFirstFrame(Duration timeout) { + if (timeout.isNegative() || timeout.isZero()) { + throw new IllegalArgumentException("Setup Handling Timeout should be greater than zero"); + } + this.timeout = timeout; + return this; + } + + /** + * When this is set, frames larger than the given maximum transmission unit (mtu) size value are + * fragmented. + * + *

By default this is not set in which case payloads are sent whole up to the maximum frame + * size of 16,777,215 bytes. + * + * @param mtu the threshold size for fragmentation, must be no less than 64 + * @return the same instance for method chaining + * @see Fragmentation + * and Reassembly + */ + public RSocketServer fragment(int mtu) { + this.mtu = assertMtu(mtu); + return this; + } + + /** + * Configure the {@code PayloadDecoder} used to create {@link Payload}'s from incoming raw frame + * buffers. The following decoders are available: + * + *

    + *
  • {@link PayloadDecoder#DEFAULT} -- the data and metadata are independent copies of the + * underlying frame {@link ByteBuf} + *
  • {@link PayloadDecoder#ZERO_COPY} -- the data and metadata are retained slices of the + * underlying {@link ByteBuf}. That's more efficient but requires careful tracking and + * {@link Payload#release() release} of the payload when no longer needed. + *
+ * + *

By default this is set to {@link PayloadDecoder#DEFAULT} in which case data and metadata are + * copied and do not need to be tracked and released. + * + * @param decoder the decoder to use + * @return the same instance for method chaining + */ + public RSocketServer payloadDecoder(PayloadDecoder decoder) { + Objects.requireNonNull(decoder); + this.payloadDecoder = decoder; + return this; + } + + /** + * Start the server on the given transport. + * + *

The following transports are available from additional RSocket Java modules: + * + *

    + *
  • {@link io.rsocket.transport.netty.client.TcpServerTransport TcpServerTransport} via + * {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.netty.client.WebsocketServerTransport + * WebsocketServerTransport} via {@code rsocket-transport-netty}. + *
  • {@link io.rsocket.transport.local.LocalServerTransport LocalServerTransport} via {@code + * rsocket-transport-local} + *
+ * + * @param transport the transport of choice to connect with + * @param the type of {@code Closeable} for the given transport + * @return a {@code Mono} with a {@code Closeable} that can be used to obtain information about + * the server, stop it, or be notified of when it is stopped. + */ + public Mono bind(ServerTransport transport) { + return Mono.defer( + new Supplier>() { + final ServerSetup serverSetup = serverSetup(timeout); + + @Override + public Mono get() { + int maxFrameLength = transport.maxFrameLength(); + assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); + return transport + .start(duplexConnection -> acceptor(serverSetup, duplexConnection, maxFrameLength)) + .doOnNext(c -> c.onClose().doFinally(v -> serverSetup.dispose()).subscribe()); + } + }); + } + + /** + * Start the server on the given transport. Effectively is a shortcut for {@code + * .bind(ServerTransport).block()} + */ + public T bindNow(ServerTransport transport) { + return bind(transport).block(); + } + /** + * An alternative to {@link #bind(ServerTransport)} that is useful for installing RSocket on a + * server that is started independently. + * + * @see io.rsocket.examples.transport.ws.WebSocketHeadersSample + */ + public ServerTransport.ConnectionAcceptor asConnectionAcceptor() { + return asConnectionAcceptor(FRAME_LENGTH_MASK); + } + + /** + * An alternative to {@link #bind(ServerTransport)} that is useful for installing RSocket on a + * server that is started independently. + * + * @see io.rsocket.examples.transport.ws.WebSocketHeadersSample + */ + public ServerTransport.ConnectionAcceptor asConnectionAcceptor(int maxFrameLength) { + assertValidateSetup(maxFrameLength, maxInboundPayloadSize, mtu); + return new ServerTransport.ConnectionAcceptor() { + private final ServerSetup serverSetup = serverSetup(timeout); + + @Override + public Mono apply(DuplexConnection connection) { + return acceptor(serverSetup, connection, maxFrameLength); + } + }; + } + + private Mono acceptor( + ServerSetup serverSetup, DuplexConnection sourceConnection, int maxFrameLength) { + + final DuplexConnection interceptedConnection = + interceptors.initConnection(DuplexConnectionInterceptor.Type.SOURCE, sourceConnection); + + return serverSetup + .init(LoggingDuplexConnection.wrapIfEnabled(interceptedConnection)) + .flatMap( + tuple2 -> { + final ByteBuf startFrame = tuple2.getT1(); + final DuplexConnection clientServerConnection = tuple2.getT2(); + + return accept(serverSetup, startFrame, clientServerConnection, maxFrameLength); + }); + } + + private Mono acceptResume( + ServerSetup serverSetup, ByteBuf resumeFrame, DuplexConnection clientServerConnection) { + return serverSetup.acceptRSocketResume(resumeFrame, clientServerConnection); + } + + private Mono accept( + ServerSetup serverSetup, + ByteBuf startFrame, + DuplexConnection clientServerConnection, + int maxFrameLength) { + switch (FrameHeaderCodec.frameType(startFrame)) { + case SETUP: + return acceptSetup(serverSetup, startFrame, clientServerConnection, maxFrameLength); + case RESUME: + return acceptResume(serverSetup, startFrame, clientServerConnection); + default: + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException("SETUP or RESUME frame must be received before any others")); + return clientServerConnection.onClose(); + } + } + + private Mono acceptSetup( + ServerSetup serverSetup, + ByteBuf setupFrame, + DuplexConnection clientServerConnection, + int maxFrameLength) { + + if (!SetupFrameCodec.isSupportedVersion(setupFrame)) { + serverSetup.sendError( + clientServerConnection, + new InvalidSetupException( + "Unsupported version: " + SetupFrameCodec.humanReadableVersion(setupFrame))); + return clientServerConnection.onClose(); + } + + boolean leaseEnabled = leaseConfigurer != null; + if (SetupFrameCodec.honorLease(setupFrame) && !leaseEnabled) { + serverSetup.sendError( + clientServerConnection, new InvalidSetupException("lease is not supported")); + return clientServerConnection.onClose(); + } + + return serverSetup.acceptRSocketSetup( + setupFrame, + clientServerConnection, + (keepAliveHandler, wrappedDuplexConnection) -> { + ConnectionSetupPayload setupPayload = + new DefaultConnectionSetupPayload(setupFrame.retain()); + final InitializingInterceptorRegistry interceptors = this.interceptors; + final ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(wrappedDuplexConnection, interceptors, false); + + final LeaseSpec leases; + final RequesterLeaseTracker requesterLeaseTracker; + if (leaseEnabled) { + leases = new LeaseSpec(); + leaseConfigurer.accept(leases); + requesterLeaseTracker = + new RequesterLeaseTracker(SERVER_TAG, leases.maxPendingRequests); + } else { + leases = null; + requesterLeaseTracker = null; + } + + final Sinks.Empty requesterOnAllClosedSink = Sinks.unsafe().empty(); + final Sinks.Empty responderOnAllClosedSink = Sinks.unsafe().empty(); + + RSocket rSocketRequester = + new RSocketRequester( + multiplexer.asServerConnection(), + payloadDecoder, + StreamIdSupplier.serverSupplier(), + mtu, + maxFrameLength, + maxInboundPayloadSize, + setupPayload.keepAliveInterval(), + setupPayload.keepAliveMaxLifetime(), + keepAliveHandler, + interceptors::initRequesterRequestInterceptor, + requesterLeaseTracker, + requesterOnAllClosedSink, + Mono.whenDelayError( + responderOnAllClosedSink.asMono(), requesterOnAllClosedSink.asMono())); + + RSocket wrappedRSocketRequester = interceptors.initRequester(rSocketRequester); + + return interceptors + .initSocketAcceptor(acceptor) + .accept(setupPayload, wrappedRSocketRequester) + .onErrorResume( + err -> + Mono.fromRunnable( + () -> + serverSetup.sendError( + wrappedDuplexConnection, rejectedSetupError(err))) + .then(wrappedDuplexConnection.onClose()) + .then(Mono.error(err))) + .doOnNext( + rSocketHandler -> { + RSocket wrappedRSocketHandler = interceptors.initResponder(rSocketHandler); + DuplexConnection clientConnection = multiplexer.asClientConnection(); + + ResponderLeaseTracker responderLeaseTracker = + leaseEnabled + ? new ResponderLeaseTracker(SERVER_TAG, clientConnection, leases.sender) + : null; + + RSocket rSocketResponder = + new RSocketResponder( + clientConnection, + wrappedRSocketHandler, + payloadDecoder, + responderLeaseTracker, + mtu, + maxFrameLength, + maxInboundPayloadSize, + leaseEnabled && leases.sender instanceof TrackingLeaseSender + ? rSocket -> + interceptors.initResponderRequestInterceptor( + rSocket, (RequestInterceptor) leases.sender) + : interceptors::initResponderRequestInterceptor, + responderOnAllClosedSink); + }) + .doFinally(signalType -> setupPayload.release()) + .then(); + }); + } + + private ServerSetup serverSetup(Duration timeout) { + return resume != null ? createSetup(timeout) : new ServerSetup.DefaultServerSetup(timeout); + } + + ServerSetup createSetup(Duration timeout) { + return new ServerSetup.ResumableServerSetup( + timeout, + new SessionManager(), + resume.getSessionDuration(), + resume.getStreamTimeout(), + resume.getStoreFactory(SERVER_TAG), + resume.isCleanupStoreOnKeepAlive()); + } + + private RSocketErrorException rejectedSetupError(Throwable err) { + String msg = err.getMessage(); + return new RejectedSetupException(msg == null ? "rejected by server acceptor" : msg); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java new file mode 100644 index 000000000..8e084fe9c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReassemblyUtils.java @@ -0,0 +1,247 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.MIN_MTU_SIZE; +import static io.rsocket.core.StateUtils.isReassembling; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.markReassembled; +import static io.rsocket.core.StateUtils.markReassembling; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; + +class ReassemblyUtils { + static final String ILLEGAL_REASSEMBLED_PAYLOAD_SIZE = + "Reassembled payload size went out of allowed %s bytes"; + + @SuppressWarnings("ConstantConditions") + static void release(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + frames.release(); + } + } + + @SuppressWarnings({"ConstantConditions", "SynchronizationOnLocalVariableOrMethodParameter"}) + static void synchronizedRelease(RequesterFrameHandler framesHolder, long state) { + if (isReassembling(state)) { + final CompositeByteBuf frames = framesHolder.getFrames(); + framesHolder.setFrames(null); + + synchronized (frames) { + frames.release(); + } + } + } + + static void handleNextSupport( + AtomicLongFieldUpdater updater, + T instance, + Subscription subscription, + CoreSubscriber inboundSubscriber, + PayloadDecoder payloadDecoder, + ByteBufAllocator allocator, + int maxInboundPayloadSize, + ByteBuf frame, + boolean hasFollows, + boolean isLastPayload) { + + long state = updater.get(instance); + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = payloadDecoder.apply(frame); + } catch (Throwable t) { + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + if (isLastPayload) { + instance.handleComplete(); + } + return; + } + + CompositeByteBuf frames = instance.getFrames(); + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), frame, hasFollows, maxInboundPayloadSize); + instance.setFrames(frames); + + long previousState = markReassembling(updater, instance); + if (isTerminated(previousState)) { + instance.setFrames(null); + frames.release(); + return; + } + } else { + try { + frames = + ReassemblyUtils.addFollowingFrame(frames, frame, hasFollows, maxInboundPayloadSize); + } catch (IllegalStateException t) { + if (isTerminated(updater.get(instance))) { + return; + } + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(updater, instance); + if (isTerminated(previousState)) { + return; + } + + instance.setFrames(null); + + Payload payload; + try { + payload = payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + // sends cancel frame to prevent any further frames + subscription.cancel(); + // terminates downstream + inboundSubscriber.onError(t); + + return; + } + + instance.handlePayload(payload); + + if (isLastPayload) { + instance.handleComplete(); + } + } + } + + static CompositeByteBuf addFollowingFrame( + CompositeByteBuf frames, + ByteBuf followingFrame, + boolean hasFollows, + int maxInboundPayloadSize) { + int readableBytes = frames.readableBytes(); + if (readableBytes == 0) { + return frames.addComponent(true, followingFrame.retain()); + } else if (maxInboundPayloadSize != Integer.MAX_VALUE + && readableBytes + followingFrame.readableBytes() - FrameHeaderCodec.size() + > maxInboundPayloadSize) { + throw new IllegalStateException( + String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)); + } else if (followingFrame.readableBytes() < MIN_MTU_SIZE - 3 && hasFollows) { + // FIXME: check MIN_MTU_SIZE only (currently fragments have size of 61) + throw new IllegalStateException("Fragment is too small."); + } + + final boolean hasMetadata = FrameHeaderCodec.hasMetadata(followingFrame); + + // skip headers + followingFrame.skipBytes(FrameHeaderCodec.size()); + + // if has metadata, then we have to increase metadata length in containing frames + // CompositeByteBuf + if (hasMetadata) { + final FrameType frameType = FrameHeaderCodec.frameType(frames); + final int lengthFieldPosition = + FrameHeaderCodec.size() + (frameType.hasInitialRequestN() ? Integer.BYTES : 0); + + frames.markReaderIndex(); + frames.skipBytes(lengthFieldPosition); + + final int nextMetadataLength = decodeLength(frames) + decodeLength(followingFrame); + + frames.resetReaderIndex(); + + frames.markWriterIndex(); + frames.writerIndex(lengthFieldPosition); + + encodeLength(frames, nextMetadataLength); + + frames.resetWriterIndex(); + } + + synchronized (frames) { + if (frames.refCnt() > 0) { + followingFrame.retain(); + return frames.addComponent(true, followingFrame); + } else { + throw new IllegalReferenceCountException(0); + } + } + } + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + int length = (byteBuf.readByte() & 0xFF) << 16; + length |= (byteBuf.readByte() & 0xFF) << 8; + length |= byteBuf.readByte() & 0xFF; + return length; + } + + static int assertInboundPayloadSize(int inboundPayloadSize) { + if (inboundPayloadSize < MIN_MTU_SIZE) { + String msg = + String.format( + "The min allowed inboundPayloadSize size is %d bytes, provided: %d", + FrameLengthCodec.FRAME_LENGTH_MASK, inboundPayloadSize); + throw new IllegalArgumentException(msg); + } else { + return inboundPayloadSize; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java new file mode 100644 index 000000000..afad6e0df --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ReconnectMono.java @@ -0,0 +1,275 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class ReconnectMono extends Mono implements Invalidatable, Disposable, Scannable { + + final Mono source; + final BiConsumer onValueReceived; + final Consumer onValueExpired; + final ResolvingInner resolvingInner; + + ReconnectMono( + Mono source, + Consumer onValueExpired, + BiConsumer onValueReceived) { + this.source = source; + this.onValueExpired = onValueExpired; + this.onValueReceived = onValueReceived; + this.resolvingInner = new ResolvingInner<>(this); + } + + public Mono getSource() { + return source; + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return source; + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + + final boolean isDisposed = isDisposed(); + if (key == Attr.TERMINATED) return isDisposed; + if (key == Attr.ERROR) return this.resolvingInner.t; + + return null; + } + + @Override + public void invalidate() { + this.resolvingInner.invalidate(); + } + + @Override + public void dispose() { + this.resolvingInner.terminate( + new CancellationException("ReconnectMono has already been disposed")); + } + + @Override + public boolean isDisposed() { + return this.resolvingInner.isDisposed(); + } + + @Override + @SuppressWarnings("uncheked") + public void subscribe(CoreSubscriber actual) { + final ResolvingOperator.MonoDeferredResolutionOperator inner = + new ResolvingOperator.MonoDeferredResolutionOperator<>(this.resolvingInner, actual); + actual.onSubscribe(inner); + + this.resolvingInner.observe(inner); + } + + /** + * Block the calling thread indefinitely, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @return the value of this {@code ReconnectMono} + */ + @Override + @Nullable + public T block() { + return block(null); + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ReconnectMono} is completed with an error a RuntimeException that + * wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@code ReconnectMono} or {@code null} if the timeout is reached and + * the {@code ReconnectMono} has not completed + */ + @Override + @Nullable + @SuppressWarnings("uncheked") + public T block(@Nullable Duration timeout) { + return this.resolvingInner.block(timeout); + } + + /** + * Subscriber that subscribes to the source {@link Mono} to receive its value.
+ * Note that the source is not expected to complete empty, and if this happens, execution will + * terminate with an {@code IllegalStateException}. + */ + static final class ReconnectMainSubscriber implements CoreSubscriber { + + final ResolvingInner parent; + + volatile Subscription s; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + ReconnectMainSubscriber.class, Subscription.class, "s"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ReconnectMainSubscriber.class, "wip"); + + T value; + + ReconnectMainSubscriber(ResolvingInner parent) { + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final T value = this.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + this.doFinally(); + return; + } + + final ResolvingInner p = this.parent; + if (value == null) { + p.terminate(new IllegalStateException("Source completed empty")); + } else { + p.complete(value); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + this.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doFinally(); + // terminate upstream which means retryBackoff has exhausted + this.parent.terminate(t); + } + + @Override + public void onNext(T value) { + if (this.s == Operators.cancelledSubscription()) { + this.parent.doOnValueExpired(value); + return; + } + + this.value = value; + // volatile write and check on racing + this.doFinally(); + } + + void dispose() { + if (Operators.terminate(S, this)) { + this.doFinally(); + } + } + + final void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + if (value != null && this.s == Operators.cancelledSubscription()) { + this.value = null; + this.parent.doOnValueExpired(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + } + + static final class ResolvingInner extends ResolvingOperator implements Scannable { + + final ReconnectMono parent; + final ReconnectMainSubscriber mainSubscriber; + + ResolvingInner(ReconnectMono parent) { + this.parent = parent; + this.mainSubscriber = new ReconnectMainSubscriber<>(this); + } + + @Override + protected void doOnValueExpired(T value) { + this.parent.onValueExpired.accept(value); + } + + @Override + protected void doOnValueResolved(T value) { + this.parent.onValueReceived.accept(value, this.parent); + } + + @Override + protected void doOnDispose() { + this.mainSubscriber.dispose(); + } + + @Override + protected void doSubscribe() { + this.parent.source.subscribe(this.mainSubscriber); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.parent; + return null; + } + } +} + +interface Invalidatable { + + void invalidate(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java new file mode 100644 index 000000000..aab491793 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelRequesterFlux.java @@ -0,0 +1,829 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.DISCARD_CONTEXT; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.Objects; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; +import reactor.util.context.ContextView; + +final class RequestChannelRequesterFlux extends Flux + implements RequesterFrameHandler, + LeasePermitHandler, + CoreSubscriber, + Subscription, + Scannable { + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + final Publisher payloadsPublisher; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelRequesterFlux.class, "state"); + + int streamId; + + boolean isFirstSignal = true; + Payload firstPayload; + + Subscription outboundSubscription; + boolean outboundDone; + Throwable outboundError; + + Context cachedContext; + CoreSubscriber inboundSubscriber; + boolean inboundDone; + long requested; + long produced; + + CompositeByteBuf frames; + + RequestChannelRequesterFlux( + Publisher payloadsPublisher, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payloadsPublisher = payloadsPublisher; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestChannelFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + Operators.error(actual, e); + return; + } + + this.inboundSubscriber = actual; + this.payloadsPublisher.subscribe(this); + } + + @Override + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + this.inboundSubscriber.onSubscribe(this); + } + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + long previousState = addRequestN(STATE, this, n, this.requesterLeaseTracker == null); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + // do first request + this.outboundSubscription.request(1); + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + p.release(); + return; + } + + if (this.isFirstSignal) { + this.isFirstSignal = false; + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + + if (leaseEnabled) { + this.firstPayload = p; + + final long previousState = markFirstPayloadReceived(STATE, this); + if (isTerminated(previousState)) { + this.firstPayload = null; + p.release(); + return; + } + + requesterLeaseTracker.issue(this); + } else { + final long state = this.state; + if (isTerminated(state)) { + p.release(); + return; + } + // TODO: check if source is Scalar | Callable | Mono + sendFirstPayload(p, extractRequestN(state), false); + } + } else { + sendFollowingPayload(p); + } + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + sendFirstPayload( + firstPayload, extractRequestN(previousState), isOutboundTerminated(previousState)); + return true; + } + + void sendFirstPayload(Payload firstPayload, long initialRequestN, boolean completed) { + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, firstPayload, true)) { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + } catch (IllegalReferenceCountException e) { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_CHANNEL, null); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(e); + return; + } + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + final long previousState = markTerminated(STATE, this); + + firstPayload.release(); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(ut); + + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_CHANNEL, firstPayload.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_CHANNEL, + initialRequestN, + mtu, + firstPayload, + connection, + allocator, + completed); + } catch (Throwable t) { + final long previousState = markTerminated(STATE, this); + + firstPayload.release(); + + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + sm.remove(streamId, this); + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + // now, this can be terminated in case of the following scenarios: + // + // 1) SendFirst is called synchronously from onNext, thus we can have + // handleError called before we marked first frame sent, thus we may check if + // inboundDone flag is true and exit execution without any further actions: + if (this.inboundDone) { + return; + } + + sm.remove(streamId, this); + + // 2) SendFirst is called asynchronously on the connection event-loop. Thus, we + // need to check if outbound error is present. Note, we check outboundError since + // in the last scenario, cancellation may terminate the state and async + // onComplete may set outboundDone to true. Thus, we explicitly check for + // outboundError + final Throwable outboundError = this.outboundError; + if (outboundError != null) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, outboundError); + connection.sendFrame(streamId, errorFrame); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, outboundError); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(outboundError); + } else { + // 3) SendFirst is interleaving with cancel. Thus, we need to generate cancel + // frame + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_CHANNEL); + } + } + + return; + } + + if (!completed && isOutboundTerminated(previousState)) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + connection.sendFrame(streamId, completeFrame); + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + connection.sendFrame(streamId, requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + connection.sendFrame(streamId, requestNFrame); + } + } + + final void sendFollowingPayload(Payload followingPayload) { + int streamId = this.streamId; + int mtu = this.mtu; + + try { + if (!isValid(mtu, this.maxFrameLength, followingPayload, true)) { + followingPayload.release(); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + return; + } + } catch (IllegalReferenceCountException e) { + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + + return; + } + + try { + sendReleasingPayload( + streamId, + + // TODO: Should be a different flag in case of the scalar + // source or if we know in advance upstream is mono + FrameType.NEXT, + mtu, + followingPayload, + this.connection, + allocator, + true); + } catch (Throwable e) { + if (!this.tryCancel()) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.propagateErrorSafely(e); + } + } + + void propagateErrorSafely(Throwable t) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + if (!this.inboundDone) { + synchronized (this) { + if (!this.inboundDone) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + + @Override + public final void cancel() { + if (!tryCancel()) { + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + } + + boolean tryCancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return false; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + if (!isReadyToSendFirstFrame(previousState) && isFirstPayloadReceived(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + // no need to send anything, since we have not started a stream yet (no logical wire) + return false; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final boolean firstFrameSent = isFirstFrameSent(previousState); + if (firstFrameSent) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.connection.sendFrame(streamId, cancelFrame); + } + + return firstFrameSent; + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundError = t; + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + if (this.isFirstSignal) { + this.inboundDone = true; + this.inboundSubscriber.onError(t); + return; + } else if (!isReadyToSendFirstFrame(previousState)) { + // first signal is received but we are still waiting for lease permit to be issued, + // thus, just propagates error to actual subscriber + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + firstPayload.release(); + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + // propagates error to remote responder + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + if (!isInboundTerminated(previousState)) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + this.inboundDone = true; + this.inboundSubscriber.onError(t); + } + } else { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + } + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, true); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + if (!isFirstFrameSent(previousState)) { + if (!isFirstPayloadReceived(previousState)) { + // first signal, thus, just propagates error to actual subscriber + this.inboundSubscriber.onError(new CancellationException("Empty Source")); + } + return; + } + + final int streamId = this.streamId; + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + + this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated(previousState)) { + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isOutboundTerminated(previousState)) { + this.requesterResponderSupport.remove(this.streamId, this); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + final Payload p = this.firstPayload; + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onReject(cause, FrameType.REQUEST_CHANNEL, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.inboundDone) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + if (!isOutboundTerminated(previousState)) { + this.outboundSubscription.cancel(); + } + + ReassemblyUtils.release(this, previousState); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handlePayload(Payload value) { + synchronized (this) { + if (this.inboundDone) { + value.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + value.release(); + if (!tryCancel()) { + return; + } + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(value); + } + } + + @Override + public void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public void handleCancel() { + if (this.outboundDone) { + return; + } + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState) || isOutboundTerminated(previousState)) { + return; + } + + final boolean inboundTerminated = isInboundTerminated(previousState); + if (inboundTerminated) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + this.outboundSubscription.cancel(); + + if (inboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + @NonNull + public Context currentContext() { + long state = this.state; + + if (isSubscribedOrTerminated(state)) { + Context cachedContext = this.cachedContext; + if (cachedContext == null) { + cachedContext = + this.inboundSubscriber.currentContext().putAll((ContextView) DISCARD_CONTEXT); + this.cachedContext = cachedContext; + } + return cachedContext; + } + + return Context.empty(); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return state; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestChannelFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java new file mode 100644 index 000000000..32128fee4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestChannelResponderSubscriber.java @@ -0,0 +1,922 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; +import static reactor.core.Exceptions.TERMINATED; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestChannelResponderSubscriber extends Flux + implements ResponderFrameHandler, Subscription, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestChannelResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final long firstRequest; + + @Nullable final RequestInterceptor requestInterceptor; + + final RSocket handler; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestChannelResponderSubscriber.class, "state"); + + Payload firstPayload; + + Subscription outboundSubscription; + CoreSubscriber inboundSubscriber; + + CompositeByteBuf frames; + + volatile Throwable inboundError; + static final AtomicReferenceFieldUpdater + INBOUND_ERROR = + AtomicReferenceFieldUpdater.newUpdater( + RequestChannelResponderSubscriber.class, Throwable.class, "inboundError"); + + boolean inboundDone; + boolean outboundDone; + long requested; + long produced; + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + this.firstRequest = firstRequestN; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + STATE.lazySet(this, REASSEMBLING_FLAG); + } + + public RequestChannelResponderSubscriber( + int streamId, + long firstRequestN, + Payload firstPayload, + RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.firstRequest = firstRequestN; + this.firstPayload = firstPayload; + + this.handler = null; + this.frames = null; + } + + @Override + // subscriber from the requestChannel method + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isTerminated(previousState)) { + Throwable t = Exceptions.terminate(INBOUND_ERROR, this); + if (t != TERMINATED) { + //noinspection ConstantConditions + Operators.error(actual, t); + } else { + Operators.error( + actual, + new CancellationException("RequestChannelSubscriber has already been terminated")); + } + return; + } + + if (isSubscribed(previousState)) { + Operators.error( + actual, new IllegalStateException("RequestChannelSubscriber allows only one Subscriber")); + return; + } + + this.inboundSubscriber = actual; + // sends sender as a subscription since every request|cancel signal should be encoded to + // requestNFrame|cancelFrame + actual.onSubscribe(this); + } + + @Override + // subscription to the outbound + public void onSubscribe(Subscription outboundSubscription) { + if (Operators.validate(this.outboundSubscription, outboundSubscription)) { + this.outboundSubscription = outboundSubscription; + outboundSubscription.request(this.firstRequest); + } + } + + @Override + public void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + long previousState = StateUtils.addRequestN(STATE, this, n); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null || this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + Payload firstPayload = this.firstPayload; + if (firstPayload != null) { + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else { + inboundSubscriber.onComplete(); + } + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (!hasRequested(previousState) && !isFirstFrameSent(previousState) && this.inboundDone) { + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + inboundSubscriber.onComplete(); + + markFirstFrameSent(STATE, this); + } + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(StateUtils.extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + final CoreSubscriber inboundSubscriber = this.inboundSubscriber; + + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + + this.produced++; + + inboundSubscriber.onNext(firstPayload); + + previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + // full termination can be the result of both sides completion / cancelFrame / remote or local + // error + // therefore, we need to check inbound error value, to see what should be done + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError == TERMINATED) { + // means inbound was already terminated + return; + } + + if (inboundError != null) { + inboundSubscriber.onError(inboundError); + } else if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + if (isInboundTerminated(previousState)) { + // inbound only can be terminated in case of cancellation or complete frame + if (this.inboundDone) { + inboundSubscriber.onComplete(); + } + return; + } + + long requestN = StateUtils.extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + this.connection.sendFrame(streamId, requestNFrame); + } else { + long firstRequestN = requestN - 1; + if (firstRequestN > 0) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(this.allocator, streamId, firstRequestN); + this.connection.sendFrame(streamId, requestNFrame); + } + } + } + + @Override + // inbound cancellation + public void cancel() { + long previousState = markInboundTerminated(STATE, this); + if (isTerminated(previousState) || isInboundTerminated(previousState)) { + INBOUND_ERROR.lazySet(this, TERMINATED); + return; + } + + if (!isFirstFrameSent(previousState) && !hasRequested(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + + final int streamId = this.streamId; + + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf cancelFrame = CancelFrameCodec.encode(this.allocator, streamId); + this.connection.sendFrame(streamId, cancelFrame); + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleCancel() { + Subscription outboundSubscription = this.outboundSubscription; + if (outboundSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + lazyTerminate(STATE, this); + + this.requesterResponderSupport.remove(this.streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } else { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onCancel(this.streamId, FrameType.REQUEST_CHANNEL); + } + } + + final long tryTerminate(boolean isFromInbound) { + Exceptions.addThrowable( + INBOUND_ERROR, this, new CancellationException("Inbound has been canceled")); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return previousState; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + if (isFromInbound) { + frames.release(); + } else { + synchronized (frames) { + frames.release(); + } + } + } + + final Subscription outboundSubscription = this.outboundSubscription; + if (outboundSubscription == null) { + return previousState; + } + + outboundSubscription.cancel(); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + if (isFromInbound) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } else { + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + } + + return previousState; + } + + final void handlePayload(Payload p) { + synchronized (this) { + if (this.inboundDone) { + // payload from network so it has refCnt > 0 + p.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + this.inboundDone = true; + + final Throwable cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, cause); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + this.connection.sendFrame( + streamId, + ErrorFrameCodec.encode( + this.allocator, streamId, new CanceledException(cause.getMessage()))); + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, cause); + } + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(p); + } + } + + @Override + public final void handleError(Throwable t) { + if (this.inboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.inboundDone = true; + boolean wasThrowableAdded = Exceptions.addThrowable(INBOUND_ERROR, this, t); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + if (!wasThrowableAdded) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + } + return; + } + + this.requesterResponderSupport.remove(this.streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + frames.release(); + } + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (isFirstFrameSent(previousState) && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + + // this is downstream subscription so need to cancel it just in case error signal has not + // reached it + // needs for disconnected upstream and downstream case + this.outboundSubscription.cancel(); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + } + + @Override + public void handleComplete() { + if (this.inboundDone) { + return; + } + + this.inboundDone = true; + + long previousState = markInboundTerminated(STATE, this); + + final boolean isOutboundTerminated = isOutboundTerminated(previousState); + if (isOutboundTerminated) { + this.requesterResponderSupport.remove(this.streamId, this); + } + + if (isFirstFrameSent(previousState)) { + this.inboundSubscriber.onComplete(); + } + + if (isOutboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + long state = this.state; + if (isTerminated(state)) { + return; + } + + if (!hasFollows && !isReassembling(state)) { + Payload payload; + try { + payload = this.payloadDecoder.apply(frame); + } catch (Throwable t) { + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode(this.allocator, streamId, new CanceledException(t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + return; + } + + this.handlePayload(payload); + if (isLastPayload) { + this.handleComplete(); + } + return; + } + + CompositeByteBuf frames = this.frames; + if (frames == null) { + frames = + ReassemblyUtils.addFollowingFrame( + this.allocator.compositeBuffer(), frame, hasFollows, this.maxInboundPayloadSize); + this.frames = frames; + + long previousState = markReassembling(STATE, this); + if (isTerminated(previousState)) { + this.frames = null; + frames.release(); + return; + } + } else { + try { + frames = + ReassemblyUtils.addFollowingFrame( + frames, frame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException e) { + if (isTerminated(this.state)) { + return; + } + + long previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + this.outboundDone = true; + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + return; + } + } + + if (!hasFollows) { + long previousState = markReassembled(STATE, this); + if (isTerminated(previousState)) { + return; + } + + this.frames = null; + + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + ReferenceCountUtil.safeRelease(frames); + + previousState = this.tryTerminate(true); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(this.streamId, FrameType.REQUEST_CHANNEL, t); + } + + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + // send error to terminate interaction + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + + return; + } + + if (this.outboundSubscription == null) { + this.firstPayload = payload; + Flux source = this.handler.requestChannel(this); + source.subscribe(this); + } else { + this.handlePayload(payload); + } + + if (isLastPayload) { + this.handleComplete(); + } + } + } + + @Override + public void onNext(Payload p) { + if (this.outboundDone) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState)) { + Operators.onErrorDropped( + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)), + this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + if (isTerminated(previousState)) { + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } else if (isOutboundTerminated(previousState)) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + + Operators.onErrorDropped(e, this.inboundSubscriber.currentContext()); + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause:" + e.getMessage())); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, connection, allocator, false); + } catch (Throwable t) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + long previousState = this.tryTerminate(false); + final RequestInterceptor interceptor = requestInterceptor; + if (interceptor != null && !isTerminated(previousState)) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + } + } + + @Override + public void onError(Throwable t) { + if (this.outboundDone) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + boolean wasThrowableAdded = + Exceptions.addThrowable( + INBOUND_ERROR, + this, + new CancellationException("Outbound has terminated with an error")); + this.outboundDone = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(t, this.inboundSubscriber.currentContext()); + return; + } + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + if (isReassembling(previousState)) { + final CompositeByteBuf frames = this.frames; + this.frames = null; + synchronized (frames) { + frames.release(); + } + } + + if (!isSubscribed(previousState)) { + final Payload firstPayload = this.firstPayload; + this.firstPayload = null; + firstPayload.release(); + } else if (wasThrowableAdded + && isFirstFrameSent(previousState) + && !isInboundTerminated(previousState)) { + Throwable inboundError = Exceptions.terminate(INBOUND_ERROR, this); + if (inboundError != TERMINATED) { + // FIXME: must be scheduled on the connection event-loop to achieve serial + // behaviour on the inbound subscriber + synchronized (this) { + this.inboundDone = true; + //noinspection ConstantConditions + this.inboundSubscriber.onError(inboundError); + } + } + } + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, t); + } + } + + @Override + public void onComplete() { + if (this.outboundDone) { + return; + } + + this.outboundDone = true; + + long previousState = markOutboundTerminated(STATE, this, false); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + final boolean isInboundTerminated = isInboundTerminated(previousState); + if (isInboundTerminated) { + this.requesterResponderSupport.remove(streamId, this); + } + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.connection.sendFrame(streamId, completeFrame); + + if (isInboundTerminated) { + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_CHANNEL, null); + } + } + } + + @Override + public final void handleRequestN(long n) { + this.outboundSubscription.request(n); + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java new file mode 100644 index 000000000..a13b105b5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseRequesterMono.java @@ -0,0 +1,400 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestResponseRequesterMono extends Mono + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestResponseRequesterMono.class, "state"); + + int streamId; + CoreSubscriber actual; + CompositeByteBuf frames; + boolean done; + + RequestResponseRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestResponseMono allows only a single " + "Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_RESPONSE, null); + } + + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); + + if (isTerminated(previousState) || hasRequested(previousState)) { + return; + } + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstPayload(this.payload); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload); + return true; + } + + void sendFirstPayload(Payload payload) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + final long previousState = markTerminated(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + + payload.release(); + + if (!isTerminated(previousState)) { + this.actual.onError(ut); + } + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_RESPONSE, payload.metadata()); + } + + try { + sendReleasingPayload( + streamId, FrameType.REQUEST_RESPONSE, this.mtu, payload, connection, allocator, true); + } catch (Throwable e) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + + this.actual.onError(e); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + sm.remove(streamId, this); + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } + } + + @Override + public final void cancel() { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } else if (!isReadyToSendFirstFrame(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload value) { + if (this.done) { + value.release(); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + value.release(); + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + + final CoreSubscriber a = this.actual; + a.onNext(value); + a.onComplete(); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + + this.actual.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + p.release(); + + this.actual.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, cause); + } + + this.actual.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.actual, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.PREFETCH) return 0; + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestResponseMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java new file mode 100644 index 000000000..3d9d020ff --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestResponseResponderSubscriber.java @@ -0,0 +1,358 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestResponseResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestResponseResponderSubscriber.class); + + final int streamId; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final RSocket handler; + + @Nullable final RequestInterceptor requestInterceptor; + + boolean done; + CompositeByteBuf frames; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestResponseResponderSubscriber.class, Subscription.class, "s"); + + public RequestResponseResponderSubscriber( + int streamId, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + public RequestResponseResponderSubscriber( + int streamId, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + S.lazySet(this, subscription); + subscription.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(@Nullable Payload p) { + if (this.done) { + if (p != null) { + p.release(); + } + return; + } + + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription() + || !S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + if (p != null) { + p.release(); + } + return; + } + + this.done = true; + + final int streamId = this.streamId; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + this.requesterResponderSupport.remove(streamId, this); + + if (p == null) { + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(allocator, streamId); + connection.sendFrame(streamId, completeFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + return; + } + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + currentSubscription.cancel(); + + p.release(); + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + currentSubscription.cancel(); + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT_COMPLETE, mtu, p, connection, allocator, false); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, null); + } + } catch (Throwable t) { + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + } + } + + @Override + public void onError(Throwable t) { + if (this.done) { + logger.debug("Dropped error", t); + return; + } + + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription() + || !S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + logger.debug("Dropped error", t); + return; + } + + this.done = true; + + final int streamId = this.streamId; + + this.requesterResponderSupport.remove(streamId, this); + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + } + + @Override + public void onComplete() { + onNext(null); + } + + @Override + public void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_RESPONSE); + } + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + if (frames == null) { + return; + } + + try { + ReassemblyUtils.addFollowingFrame(frames, frame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException t) { + S.lazySet(this, Operators.cancelledSubscription()); + + this.requesterResponderSupport.remove(this.streamId, this); + + this.frames = null; + frames.release(); + + logger.debug("Reassembly has failed", t); + + // sends error frame from the responder side to tell that something went wrong + final int streamId = this.streamId; + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + return; + } + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReferenceCountUtil.safeRelease(frames); + + logger.debug("Reassembly has failed", t); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_RESPONSE, t); + } + return; + } + + final Mono source = this.handler.requestResponse(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java new file mode 100644 index 000000000..6182ca506 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamRequesterFlux.java @@ -0,0 +1,449 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.ReassemblyUtils.handleNextSupport; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class RequestStreamRequesterFlux extends Flux + implements RequesterFrameHandler, LeasePermitHandler, Subscription, Scannable { + + final ByteBufAllocator allocator; + final Payload payload; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + final PayloadDecoder payloadDecoder; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(RequestStreamRequesterFlux.class, "state"); + + int streamId; + CoreSubscriber inboundSubscriber; + CompositeByteBuf frames; + boolean done; + long requested; + long produced; + + RequestStreamRequesterFlux(Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + long previousState = markSubscribed(STATE, this); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("RequestStreamFlux allows only a single Subscriber"); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + try { + if (!isValid(this.mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_STREAM, null); + } + + Operators.error(actual, e); + return; + } + + this.inboundSubscriber = actual; + actual.onSubscribe(this); + } + + @Override + public final void request(long n) { + if (!Operators.validate(n)) { + return; + } + + this.requested = Operators.addCap(this.requested, n); + + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + final long previousState = addRequestN(STATE, this, n, !leaseEnabled); + if (isTerminated(previousState)) { + return; + } + + if (hasRequested(previousState)) { + if (isFirstFrameSent(previousState) + && !isMaxAllowedRequestN(extractRequestN(previousState))) { + final int streamId = this.streamId; + final ByteBuf requestNFrame = RequestNFrameCodec.encode(this.allocator, streamId, n); + this.connection.sendFrame(streamId, requestNFrame); + } + return; + } + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstPayload(this.payload, n); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstPayload(this.payload, extractRequestN(previousState)); + return true; + } + + void sendFirstPayload(Payload payload, long initialRequestN) { + + final RequesterResponderSupport sm = this.requesterResponderSupport; + final DuplexConnection connection = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int streamId; + try { + streamId = sm.addAndGetNextStreamId(this); + this.streamId = streamId; + } catch (Throwable t) { + this.done = true; + final long previousState = markTerminated(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_STREAM, payload.metadata()); + } + + payload.release(); + + if (!isTerminated(previousState)) { + this.inboundSubscriber.onError(ut); + } + return; + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, FrameType.REQUEST_STREAM, payload.metadata()); + } + + try { + sendReleasingPayload( + streamId, + FrameType.REQUEST_STREAM, + initialRequestN, + this.mtu, + payload, + connection, + allocator, + false); + } catch (Throwable t) { + this.done = true; + lazyTerminate(STATE, this); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + this.inboundSubscriber.onError(t); + return; + } + + long previousState = markFirstFrameSent(STATE, this); + if (isTerminated(previousState)) { + if (this.done) { + return; + } + + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + + sm.remove(streamId, this); + + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + return; + } + + if (isMaxAllowedRequestN(initialRequestN)) { + return; + } + + long requestN = extractRequestN(previousState); + if (isMaxAllowedRequestN(requestN)) { + final ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, streamId, requestN); + connection.sendFrame(streamId, requestNFrame); + return; + } + + if (requestN > initialRequestN) { + final ByteBuf requestNFrame = + RequestNFrameCodec.encode(allocator, streamId, requestN - initialRequestN); + connection.sendFrame(streamId, requestNFrame); + } + } + + @Override + public final void cancel() { + final long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + if (isFirstFrameSent(previousState)) { + final int streamId = this.streamId; + + ReassemblyUtils.synchronizedRelease(this, previousState); + + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + } else if (!isReadyToSendFirstFrame(previousState)) { + // no need to send anything, since the first request has not happened + this.payload.release(); + } + } + + @Override + public final void handlePayload(Payload p) { + if (this.done) { + p.release(); + return; + } + + final long produced = this.produced; + if (this.requested == produced) { + p.release(); + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + + final IllegalStateException cause = + Exceptions.failWithOverflow( + "The number of messages received exceeds the number requested"); + this.connection.sendFrame(streamId, CancelFrameCodec.encode(this.allocator, streamId)); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + + this.inboundSubscriber.onError(cause); + return; + } + + this.produced = produced + 1; + + this.inboundSubscriber.onNext(p); + } + + @Override + public final void handleComplete() { + if (this.done) { + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } + + this.inboundSubscriber.onComplete(); + } + + @Override + public final void handlePermitError(Throwable cause) { + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_STREAM, p.metadata()); + } + p.release(); + + this.inboundSubscriber.onError(cause); + } + + @Override + public final void handleError(Throwable cause) { + if (this.done) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + this.done = true; + + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.inboundSubscriber.currentContext()); + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + ReassemblyUtils.synchronizedRelease(this, previousState); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, cause); + } + + this.inboundSubscriber.onError(cause); + } + + @Override + public void handleNext(ByteBuf frame, boolean hasFollows, boolean isLastPayload) { + handleNextSupport( + STATE, + this, + this, + this.inboundSubscriber, + this.payloadDecoder, + this.allocator, + this.maxInboundPayloadSize, + frame, + hasFollows, + isLastPayload); + } + + @Override + public CompositeByteBuf getFrames() { + return this.frames; + } + + @Override + public void setFrames(CompositeByteBuf byteBuf) { + this.frames = byteBuf; + } + + @Override + @Nullable + public Object scanUnsafe(Attr key) { + // touch guard + long state = this.state; + + if (key == Attr.TERMINATED) return isTerminated(state); + if (key == Attr.REQUESTED_FROM_DOWNSTREAM) return extractRequestN(state); + + return null; + } + + @Override + @NonNull + public String stepName() { + return "source(RequestStreamFlux)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java new file mode 100644 index 000000000..48903ae38 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequestStreamResponderSubscriber.java @@ -0,0 +1,395 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +final class RequestStreamResponderSubscriber + implements ResponderFrameHandler, CoreSubscriber { + + static final Logger logger = LoggerFactory.getLogger(RequestStreamResponderSubscriber.class); + + final int streamId; + final long firstRequest; + final ByteBufAllocator allocator; + final PayloadDecoder payloadDecoder; + final int mtu; + final int maxFrameLength; + final int maxInboundPayloadSize; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequestInterceptor requestInterceptor; + + final RSocket handler; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + RequestStreamResponderSubscriber.class, Subscription.class, "s"); + + CompositeByteBuf frames; + boolean done; + + public RequestStreamResponderSubscriber( + int streamId, + long firstRequest, + ByteBuf firstFrame, + RequesterResponderSupport requesterResponderSupport, + RSocket handler) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.payloadDecoder = requesterResponderSupport.getPayloadDecoder(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.handler = handler; + this.frames = + ReassemblyUtils.addFollowingFrame( + allocator.compositeBuffer(), firstFrame, true, maxInboundPayloadSize); + } + + public RequestStreamResponderSubscriber( + int streamId, long firstRequest, RequesterResponderSupport requesterResponderSupport) { + this.streamId = streamId; + this.firstRequest = firstRequest; + this.allocator = requesterResponderSupport.getAllocator(); + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.maxInboundPayloadSize = requesterResponderSupport.getMaxInboundPayloadSize(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + + this.payloadDecoder = null; + this.handler = null; + this.frames = null; + } + + @Override + public void onSubscribe(Subscription subscription) { + if (Operators.validate(this.s, subscription)) { + final long firstRequest = this.firstRequest; + S.lazySet(this, subscription); + subscription.request(firstRequest); + } + } + + @Override + public void onNext(Payload p) { + if (this.done) { + ReferenceCountUtil.safeRelease(p); + return; + } + + final int streamId = this.streamId; + final DuplexConnection sender = this.connection; + final ByteBufAllocator allocator = this.allocator; + + final int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + p.release(); + + if (!this.tryTerminateOnError()) { + return; + } + + final CanceledException e = + new CanceledException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, streamId, e); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + return; + } + } catch (IllegalReferenceCountException e) { + if (!this.tryTerminateOnError()) { + return; + } + + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException("Failed to validate payload. Cause" + e.getMessage())); + sender.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + return; + } + + try { + sendReleasingPayload(streamId, FrameType.NEXT, mtu, p, sender, allocator, false); + } catch (Throwable t) { + if (!this.tryTerminateOnError()) { + return; + } + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + } + } + + boolean tryTerminateOnError() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return false; + } + + this.done = true; + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return false; + } + + currentSubscription.cancel(); + + return true; + } + + @Override + public void onError(Throwable t) { + if (this.done) { + logger.debug("Dropped error", t); + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + logger.debug("Dropped error", t); + return; + } + + final CompositeByteBuf frames = this.frames; + if (frames != null && frames.refCnt() > 0) { + frames.release(); + } + + final int streamId = this.streamId; + + final ByteBuf errorFrame = ErrorFrameCodec.encode(this.allocator, streamId, t); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + this.done = true; + + if (S.getAndSet(this, Operators.cancelledSubscription()) == Operators.cancelledSubscription()) { + return; + } + + final int streamId = this.streamId; + + final ByteBuf completeFrame = PayloadFrameCodec.encodeComplete(this.allocator, streamId); + this.connection.sendFrame(streamId, completeFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, null); + } + } + + @Override + public void handleRequestN(long n) { + this.s.request(n); + } + + @Override + public final void handleCancel() { + final Subscription currentSubscription = this.s; + if (currentSubscription == Operators.cancelledSubscription()) { + return; + } + + if (currentSubscription == null) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + final CompositeByteBuf frames = this.frames; + if (frames != null) { + this.frames = null; + frames.release(); + } + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + return; + } + + if (!S.compareAndSet(this, currentSubscription, Operators.cancelledSubscription())) { + return; + } + + final int streamId = this.streamId; + this.requesterResponderSupport.remove(streamId, this); + + currentSubscription.cancel(); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onCancel(streamId, FrameType.REQUEST_STREAM); + } + } + + @Override + public void handleNext(ByteBuf followingFrame, boolean hasFollows, boolean isLastPayload) { + final CompositeByteBuf frames = this.frames; + if (frames == null) { + return; + } + + try { + ReassemblyUtils.addFollowingFrame( + frames, followingFrame, hasFollows, this.maxInboundPayloadSize); + } catch (IllegalStateException e) { + // if subscription is null, it means that streams has not yet reassembled all the fragments + // and fragmentation of the first frame was cancelled before + S.lazySet(this, Operators.cancelledSubscription()); + + final int streamId = this.streamId; + + this.frames = null; + frames.release(); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + e.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, e); + } + + logger.debug("Reassembly has failed", e); + return; + } + + if (!hasFollows) { + this.frames = null; + Payload payload; + try { + payload = this.payloadDecoder.apply(frames); + frames.release(); + } catch (Throwable t) { + S.lazySet(this, Operators.cancelledSubscription()); + this.done = true; + + final int streamId = this.streamId; + + ReferenceCountUtil.safeRelease(frames); + + // sends error frame from the responder side to tell that something went wrong + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + this.allocator, + streamId, + new CanceledException("Failed to reassemble payload. Cause: " + t.getMessage())); + this.connection.sendFrame(streamId, errorFrame); + + this.requesterResponderSupport.remove(streamId, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onTerminate(streamId, FrameType.REQUEST_STREAM, t); + } + + logger.debug("Reassembly has failed", t); + return; + } + + Flux source = this.handler.requestStream(payload); + source.subscribe(this); + } + } + + @Override + public Context currentContext() { + return SendUtils.DISCARD_CONTEXT; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java new file mode 100644 index 000000000..1f7b09af8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterFrameHandler.java @@ -0,0 +1,43 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import java.util.concurrent.CancellationException; +import reactor.util.annotation.Nullable; + +interface RequesterFrameHandler extends FrameHandler { + + void handlePayload(Payload payload); + + @Override + default void handleCancel() { + handleError( + new CancellationException( + "Cancellation was received but should not be possible for current request type")); + } + + @Override + default void handleRequestN(long n) { + // no ops + } + + @Nullable + CompositeByteBuf getFrames(); + + void setFrames(@Nullable CompositeByteBuf reassembledFrames); +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java new file mode 100644 index 000000000..50da83b8f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterLeaseTracker.java @@ -0,0 +1,135 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Availability; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.MissingLeaseException; +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Queue; + +final class RequesterLeaseTracker implements Availability { + + final String tag; + final int maximumAllowedAwaitingPermitHandlersNumber; + final Queue awaitingPermitHandlersQueue; + + Lease currentLease = null; + int availableRequests; + + boolean isDisposed; + Throwable t; + + RequesterLeaseTracker(String tag, int maximumAllowedAwaitingPermitHandlersNumber) { + this.tag = tag; + this.maximumAllowedAwaitingPermitHandlersNumber = maximumAllowedAwaitingPermitHandlersNumber; + this.awaitingPermitHandlersQueue = new ArrayDeque<>(); + } + + synchronized void issue(LeasePermitHandler leasePermitHandler) { + if (this.isDisposed) { + leasePermitHandler.handlePermitError(this.t); + return; + } + + final int availableRequests = this.availableRequests; + final Lease l = this.currentLease; + final boolean leaseReceived = l != null; + final boolean isExpired = leaseReceived && isExpired(l); + + if (leaseReceived && availableRequests > 0 && !isExpired) { + if (leasePermitHandler.handlePermit()) { + this.availableRequests = availableRequests - 1; + } + } else { + final Queue queue = this.awaitingPermitHandlersQueue; + if (this.maximumAllowedAwaitingPermitHandlersNumber > queue.size()) { + queue.offer(leasePermitHandler); + } else { + final String tag = this.tag; + final String message; + if (!leaseReceived) { + message = String.format("[%s] Lease was not received yet", tag); + } else if (isExpired) { + message = String.format("[%s] Missing leases. Lease is expired", tag); + } else { + message = + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", + tag, availableRequests); + } + + final Throwable t = new MissingLeaseException(message); + leasePermitHandler.handlePermitError(t); + } + } + } + + void handleLeaseFrame(ByteBuf leaseFrame) { + final int numberOfRequests = LeaseFrameCodec.numRequests(leaseFrame); + final int timeToLiveMillis = LeaseFrameCodec.ttl(leaseFrame); + final ByteBuf metadata = LeaseFrameCodec.metadata(leaseFrame); + + synchronized (this) { + final Lease lease = + Lease.create(Duration.ofMillis(timeToLiveMillis), numberOfRequests, metadata); + final Queue queue = this.awaitingPermitHandlersQueue; + + int availableRequests = lease.numberOfRequests(); + + this.currentLease = lease; + if (queue.size() > 0) { + do { + final LeasePermitHandler handler = queue.poll(); + if (handler.handlePermit()) { + availableRequests--; + } + } while (availableRequests > 0 && queue.size() > 0); + } + + this.availableRequests = availableRequests; + } + } + + public synchronized void dispose(Throwable t) { + this.isDisposed = true; + this.t = t; + + final Queue queue = this.awaitingPermitHandlersQueue; + final int size = queue.size(); + + for (int i = 0; i < size; i++) { + final LeasePermitHandler leasePermitHandler = queue.poll(); + + //noinspection ConstantConditions + leasePermitHandler.handlePermitError(t); + } + } + + @Override + public synchronized double availability() { + final Lease lease = this.currentLease; + return lease != null ? this.availableRequests / (double) lease.numberOfRequests() : 0.0d; + } + + static boolean isExpired(Lease currentLease) { + return System.currentTimeMillis() >= currentLease.expirationTime(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java new file mode 100644 index 000000000..bea7dc1aa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/RequesterResponderSupport.java @@ -0,0 +1,161 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import java.util.Objects; +import java.util.function.Function; +import reactor.util.annotation.Nullable; + +class RequesterResponderSupport { + + private final int mtu; + private final int maxFrameLength; + private final int maxInboundPayloadSize; + private final PayloadDecoder payloadDecoder; + private final ByteBufAllocator allocator; + private final DuplexConnection connection; + @Nullable private final RequestInterceptor requestInterceptor; + + @Nullable final StreamIdSupplier streamIdSupplier; + final IntObjectMap activeStreams; + + public RequesterResponderSupport( + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + PayloadDecoder payloadDecoder, + DuplexConnection connection, + @Nullable StreamIdSupplier streamIdSupplier, + Function requestInterceptorFunction) { + + this.activeStreams = new IntObjectHashMap<>(); + this.mtu = mtu; + this.maxFrameLength = maxFrameLength; + this.maxInboundPayloadSize = maxInboundPayloadSize; + this.payloadDecoder = payloadDecoder; + this.allocator = connection.alloc(); + this.streamIdSupplier = streamIdSupplier; + this.connection = connection; + this.requestInterceptor = requestInterceptorFunction.apply((RSocket) this); + } + + public int getMtu() { + return mtu; + } + + public int getMaxFrameLength() { + return maxFrameLength; + } + + public int getMaxInboundPayloadSize() { + return maxInboundPayloadSize; + } + + public PayloadDecoder getPayloadDecoder() { + return payloadDecoder; + } + + public ByteBufAllocator getAllocator() { + return allocator; + } + + public DuplexConnection getDuplexConnection() { + return connection; + } + + @Nullable + public RequesterLeaseTracker getRequesterLeaseTracker() { + return null; + } + + @Nullable + public RequestInterceptor getRequestInterceptor() { + return requestInterceptor; + } + + /** + * Issues next {@code streamId} + * + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int getNextStreamId() { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + synchronized (this) { + return streamIdSupplier.nextStreamId(this.activeStreams); + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + /** + * Adds frameHandler and returns issued {@code streamId} back + * + * @param frameHandler to store + * @return issued {@code streamId} + * @throws RuntimeException if the {@link RequesterResponderSupport} is terminated for any reason + */ + public int addAndGetNextStreamId(FrameHandler frameHandler) { + final StreamIdSupplier streamIdSupplier = this.streamIdSupplier; + if (streamIdSupplier != null) { + final IntObjectMap activeStreams = this.activeStreams; + synchronized (this) { + final int streamId = streamIdSupplier.nextStreamId(activeStreams); + + activeStreams.put(streamId, frameHandler); + + return streamId; + } + } else { + throw new UnsupportedOperationException("Responder can not issue id"); + } + } + + public synchronized boolean add(int streamId, FrameHandler frameHandler) { + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.putIfAbsent(key, value) without `streamId` boxing + final FrameHandler previousHandler = activeStreams.get(streamId); + if (previousHandler == null) { + activeStreams.put(streamId, frameHandler); + return true; + } + return false; + } + + /** + * Resolves {@link FrameHandler} by {@code streamId} + * + * @param streamId used to resolve {@link FrameHandler} + * @return {@link FrameHandler} or {@code null} + */ + @Nullable + public synchronized FrameHandler get(int streamId) { + return this.activeStreams.get(streamId); + } + + /** + * Removes {@link FrameHandler} if it is present and equals to the given one + * + * @param streamId to lookup for {@link FrameHandler} + * @param frameHandler instance to check with the found one + * @return {@code true} if there is {@link FrameHandler} for the given {@code streamId} and the + * instance equals to the passed one + */ + public synchronized boolean remove(int streamId, FrameHandler frameHandler) { + final IntObjectMap activeStreams = this.activeStreams; + // copy of Map.remove(key, value) without `streamId` boxing + final FrameHandler curValue = activeStreams.get(streamId); + if (!Objects.equals(curValue, frameHandler)) { + return false; + } + activeStreams.remove(streamId); + return true; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java new file mode 100644 index 000000000..50bef5b70 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResolvingOperator.java @@ -0,0 +1,646 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +// A copy of this class exists in io.rsocket.loadbalance + +class ResolvingOperator implements Disposable { + + static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ResolvingOperator.class, "wip"); + + volatile BiConsumer[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ResolvingOperator.class, BiConsumer[].class, "subscribers"); + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_UNSUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_SUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] READY = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] TERMINATED = new BiConsumer[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + public ResolvingOperator() { + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public final void dispose() { + this.terminate(ON_DISPOSE); + } + + @Override + public final boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + public final boolean isPending() { + BiConsumer[] state = this.subscribers; + return state != READY && state != TERMINATED; + } + + @Nullable + public final T valueIfResolved() { + if (this.subscribers == READY) { + T value = this.value; + if (value != null) { + return value; + } + } + + return null; + } + + final void observe(BiConsumer actual) { + for (; ; ) { + final int state = this.add(actual); + + T value = this.value; + + if (state == READY_STATE) { + if (value != null) { + actual.accept(value, null); + return; + } + // value == null means racing between invalidate and this subscriber + // thus, we have to loop again + continue; + } else if (state == TERMINATED_STATE) { + actual.accept(null, this.t); + return; + } + + return; + } + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ResolvingOperator} is completed with an error a RuntimeException + * that wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@link ResolvingOperator} or {@code null} if the timeout is reached + * and the {@link ResolvingOperator} has not completed + * @throws RuntimeException if terminated with error + * @throws IllegalStateException if timed out or {@link Thread} was interrupted with {@link + * InterruptedException} + */ + @Nullable + @SuppressWarnings({"uncheked", "BusyWait"}) + public T block(@Nullable Duration timeout) { + try { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + subscribers = this.subscribers; + + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + final void terminate(Throwable t) { + if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // writes happens before volatile write + this.t = t; + + final BiConsumer[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doOnDispose(); + + this.doFinally(); + + for (BiConsumer consumer : subscribers) { + consumer.accept(null, t); + } + } + + final void complete(T value) { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.doOnValueResolved(value); + + for (BiConsumer consumer : subscribers) { + consumer.accept(value, null); + } + } + + protected void doOnValueResolved(T value) { + // no ops + } + + final void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + if (value != null && isDisposed()) { + this.value = null; + this.doOnValueExpired(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + final void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final BiConsumer[] subscribers = this.subscribers; + + if (subscribers == READY) { + // guarded section to ensure we expire value exactly once if there is racing + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final T value = this.value; + if (value != null) { + this.value = null; + this.doOnValueExpired(value); + } + + int m = 1; + for (; ; ) { + if (isDisposed()) { + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + } + + SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED); + } + } + + protected void doOnValueExpired(T value) { + // no ops + } + + protected void doOnDispose() { + // no ops + } + + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + + final int add(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + BiConsumer[] b = new BiConsumer[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.doSubscribe(); + } + return ADDED_STATE; + } + } + } + + protected void doSubscribe() { + // no ops + } + + @SuppressWarnings("unchecked") + final void remove(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + BiConsumer[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new BiConsumer[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } + + abstract static class DeferredResolution + implements CoreSubscriber, Subscription, Scannable, BiConsumer { + + final ResolvingOperator parent; + final CoreSubscriber actual; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(DeferredResolution.class, "requested"); + + static final long STATE_SUBSCRIBED = -1; + static final long STATE_CANCELLED = Long.MIN_VALUE; + + Subscription s; + boolean done; + + DeferredResolution(ResolvingOperator parent, CoreSubscriber actual) { + this.parent = parent; + this.actual = actual; + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_CANCELLED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_CANCELLED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_CANCELLED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(T payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_CANCELLED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public boolean isCancelled() { + return this.requested == STATE_CANCELLED; + } + + public void cancel() { + long state = REQUESTED.getAndSet(this, STATE_CANCELLED); + if (state == STATE_CANCELLED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + } + } + } + + static class MonoDeferredResolutionOperator extends Operators.MonoSubscriber + implements BiConsumer { + + final ResolvingOperator parent; + + MonoDeferredResolutionOperator(ResolvingOperator parent, CoreSubscriber actual) { + super(actual); + this.parent = parent; + } + + @Override + public void accept(T t, Throwable throwable) { + if (throwable != null) { + onError(throwable); + return; + } + + complete(t); + } + + @Override + public void cancel() { + if (!isCancelled()) { + super.cancel(); + this.parent.remove(this); + } + } + + @Override + public void onComplete() { + if (!isCancelled()) { + this.actual.onComplete(); + } + } + + @Override + public void onError(Throwable t) { + if (isCancelled()) { + Operators.onErrorDropped(t, currentContext()); + } else { + this.actual.onError(t); + } + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) return this.parent; + return super.scanUnsafe(key); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java new file mode 100644 index 000000000..27cc8db9a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderFrameHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +interface ResponderFrameHandler extends FrameHandler { + + Logger logger = LoggerFactory.getLogger(ResponderFrameHandler.class); + + @Override + default void handleComplete() {} + + @Override + default void handleError(Throwable t) { + logger.debug("Dropped error", t); + handleCancel(); + } + + @Override + default void handleRequestN(long n) { + // no ops + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java new file mode 100644 index 000000000..fc7442f4a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ResponderLeaseTracker.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Availability; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.lease.Lease; +import io.rsocket.lease.LeaseSender; +import io.rsocket.lease.MissingLeaseException; +import reactor.core.Disposable; +import reactor.core.publisher.BaseSubscriber; +import reactor.util.annotation.Nullable; + +final class ResponderLeaseTracker extends BaseSubscriber + implements Disposable, Availability { + + final String tag; + final ByteBufAllocator allocator; + final DuplexConnection connection; + + @Nullable volatile MutableLease currentLease; + + ResponderLeaseTracker(String tag, DuplexConnection connection, LeaseSender leaseSender) { + this.tag = tag; + this.connection = connection; + this.allocator = connection.alloc(); + + leaseSender.send().subscribe(this); + } + + @Nullable + Throwable use() { + final MutableLease lease = this.currentLease; + final String tag = this.tag; + + if (lease == null) { + return new MissingLeaseException(String.format("[%s] Lease was not issued yet", tag)); + } + + if (isExpired(lease)) { + return new MissingLeaseException(String.format("[%s] Missing leases. Lease is expired", tag)); + } + + final int allowedRequests = lease.allowedRequests; + final int remainingRequests = lease.remainingRequests; + if (remainingRequests <= 0) { + return new MissingLeaseException( + String.format( + "[%s] Missing leases. Issued [%s] request allowance is used", tag, allowedRequests)); + } + + lease.remainingRequests = remainingRequests - 1; + + return null; + } + + @Override + protected void hookOnNext(Lease lease) { + final int allowedRequests = lease.numberOfRequests(); + final int ttl = lease.timeToLiveInMillis(); + final long expireAt = lease.expirationTime(); + + this.currentLease = new MutableLease(allowedRequests, expireAt); + this.connection.sendFrame( + 0, LeaseFrameCodec.encode(this.allocator, ttl, allowedRequests, lease.metadata())); + } + + @Override + public double availability() { + final MutableLease lease = this.currentLease; + + if (lease == null || isExpired(lease)) { + return 0; + } + + return lease.remainingRequests / (double) lease.allowedRequests; + } + + static boolean isExpired(MutableLease currentLease) { + return System.currentTimeMillis() >= currentLease.expireAt; + } + + static final class MutableLease { + final int allowedRequests; + final long expireAt; + + int remainingRequests; + + MutableLease(int allowedRequests, long expireAt) { + this.allowedRequests = allowedRequests; + this.expireAt = expireAt; + + this.remainingRequests = allowedRequests; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/Resume.java b/rsocket-core/src/main/java/io/rsocket/core/Resume.java new file mode 100644 index 000000000..fa0eedbfa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/Resume.java @@ -0,0 +1,177 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.ResumableFramesStore; +import java.time.Duration; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Supplier; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.util.retry.Retry; + +/** + * Simple holder of configuration settings for the RSocket Resume capability. This can be used to + * configure an {@link RSocketConnector} or an {@link RSocketServer} except for {@link + * #retry(Retry)} and {@link #token(Supplier)} which apply only to the client side. + */ +public class Resume { + private static final Logger logger = LoggerFactory.getLogger(Resume.class); + + private Duration sessionDuration = Duration.ofMinutes(2); + + /* Storage */ + private boolean cleanupStoreOnKeepAlive; + private Function storeFactory; + private Duration streamTimeout = Duration.ofSeconds(10); + + /* Client only */ + private Supplier tokenSupplier = ResumeFrameCodec::generateResumeToken; + private Retry retry = + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1)) + .maxBackoff(Duration.ofSeconds(16)) + .jitter(1.0) + .doBeforeRetry(signal -> logger.debug("Connection error", signal.failure())); + + public Resume() {} + + /** + * The maximum time for a client to keep trying to reconnect. During this time client and server + * continue to store unsent frames to keep the session warm and ready to resume. + * + *

By default this is set to 2 minutes. + * + * @param sessionDuration the max duration for a session + * @return the same instance for method chaining + */ + public Resume sessionDuration(Duration sessionDuration) { + this.sessionDuration = Objects.requireNonNull(sessionDuration); + return this; + } + + /** + * When this property is enabled, hints from {@code KEEPALIVE} frames about how much data has been + * received by the other side, is used to proactively clean frames from the {@link + * #storeFactory(Function) store}. + * + *

By default this is set to {@code false} in which case information from {@code KEEPALIVE} is + * ignored and old frames from the store are removed only when the store runs out of space. + * + * @return the same instance for method chaining + */ + public Resume cleanupStoreOnKeepAlive() { + this.cleanupStoreOnKeepAlive = true; + return this; + } + + /** + * Configure a factory to create the storage for buffering (or persisting) a window of frames that + * may need to be sent again to resume after a dropped connection. + * + *

By default {@link InMemoryResumableFramesStore} is used with its cache size set to 100,000 + * bytes. When the cache fills up, the oldest frames are gradually removed to create space for new + * ones. + * + * @param storeFactory the factory to use to create the store + * @return the same instance for method chaining + */ + public Resume storeFactory( + Function storeFactory) { + this.storeFactory = storeFactory; + return this; + } + + /** + * A {@link reactor.core.publisher.Flux#timeout(Duration) timeout} value to apply to the resumed + * session stream obtained from the {@link #storeFactory(Function) store} after a reconnect. The + * resume stream must not take longer than the specified time to emit each frame. + * + *

By default this is set to 10 seconds. + * + * @param streamTimeout the timeout value for resuming a session stream + * @return the same instance for method chaining + */ + public Resume streamTimeout(Duration streamTimeout) { + this.streamTimeout = Objects.requireNonNull(streamTimeout); + return this; + } + + /** + * Configure the logic for reconnecting. This setting is for use with {@link + * RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is set to: + * + *

{@code
+   * Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(1))
+   *     .maxBackoff(Duration.ofSeconds(16))
+   *     .jitter(1.0)
+   * }
+ * + * @param retry the {@code Retry} spec to use when attempting to reconnect + * @return the same instance for method chaining + */ + public Resume retry(Retry retry) { + this.retry = retry; + return this; + } + + /** + * Customize the generation of the resume identification token used to resume. This setting is for + * use with {@link RSocketConnector#resume(Resume)} on the client side only. + * + *

By default this is {@code ResumeFrameFlyweight::generateResumeToken}. + * + * @param supplier a custom generator for a resume identification token + * @return the same instance for method chaining + */ + public Resume token(Supplier supplier) { + this.tokenSupplier = supplier; + return this; + } + + // Package private accessors + + Duration getSessionDuration() { + return sessionDuration; + } + + boolean isCleanupStoreOnKeepAlive() { + return cleanupStoreOnKeepAlive; + } + + Function getStoreFactory(String tag) { + return storeFactory != null + ? storeFactory + : token -> new InMemoryResumableFramesStore(tag, token, 100_000); + } + + Duration getStreamTimeout() { + return streamTimeout; + } + + Retry getRetry() { + return retry; + } + + Supplier getTokenSupplier() { + return tokenSupplier; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java new file mode 100644 index 000000000..568dada2e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SendUtils.java @@ -0,0 +1,335 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.isFragmentable; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCounted; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.exceptions.CanceledException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import java.util.function.Consumer; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; + +final class SendUtils { + private static final Consumer DROPPED_ELEMENTS_CONSUMER = + data -> { + if (data instanceof ReferenceCounted) { + try { + ReferenceCounted referenceCounted = (ReferenceCounted) data; + referenceCounted.release(); + } catch (Throwable e) { + // ignored + } + } + }; + + static final Context DISCARD_CONTEXT = Operators.enableOnDiscard(null, DROPPED_ELEMENTS_CONSUMER); + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + int mtu, + Payload payload, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean requester) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, false); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, mtu, frameType, streamId, hasMetadata, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + connection.sendFrame(streamId, first); + + boolean complete = frameType == FrameType.NEXT_COMPLETE; + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, requester, true, e); + throw e; + } + connection.sendFrame(streamId, following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, connection, allocator, requester, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_FNF: + requestFrame = + RequestFireAndForgetFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case REQUEST_RESPONSE: + requestFrame = + RequestResponseFrameCodec.encode( + allocator, streamId, false, metadataRetainedSlice, dataRetainedSlice); + break; + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + requestFrame = + PayloadFrameCodec.encode( + allocator, + streamId, + false, + frameType == FrameType.NEXT_COMPLETE, + frameType != FrameType.PAYLOAD, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + connection.sendFrame(streamId, requestFrame); + } + } + + static void sendReleasingPayload( + int streamId, + FrameType frameType, + long initialRequestN, + int mtu, + Payload payload, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean complete) { + + final boolean hasMetadata = payload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? payload.metadata() : null; + final ByteBuf data = payload.data(); + + boolean fragmentable; + try { + fragmentable = isFragmentable(mtu, data, metadata, true); + } catch (IllegalReferenceCountException | NullPointerException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + if (fragmentable) { + final ByteBuf slicedData = data.slice(); + final ByteBuf slicedMetadata = hasMetadata ? metadata.slice() : Unpooled.EMPTY_BUFFER; + + final ByteBuf first; + try { + first = + FragmentationUtils.encodeFirstFragment( + allocator, + mtu, + initialRequestN, + frameType, + streamId, + hasMetadata, + slicedMetadata, + slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + connection.sendFrame(streamId, first); + + while (slicedData.isReadable() || slicedMetadata.isReadable()) { + final ByteBuf following; + try { + following = + FragmentationUtils.encodeFollowsFragment( + allocator, mtu, streamId, complete, slicedMetadata, slicedData); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + connection.sendFrame(streamId, following); + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + sendTerminalFrame(streamId, frameType, connection, allocator, true, true, e); + throw e; + } + } else { + final ByteBuf dataRetainedSlice = data.retainedSlice(); + + final ByteBuf metadataRetainedSlice; + try { + metadataRetainedSlice = hasMetadata ? metadata.retainedSlice() : null; + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + dataRetainedSlice.release(); + if (hasMetadata) { + metadataRetainedSlice.release(); + } + + sendTerminalFrame(streamId, frameType, connection, allocator, true, false, e); + throw e; + } + + final ByteBuf requestFrame; + switch (frameType) { + case REQUEST_STREAM: + requestFrame = + RequestStreamFrameCodec.encode( + allocator, + streamId, + false, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + case REQUEST_CHANNEL: + requestFrame = + RequestChannelFrameCodec.encode( + allocator, + streamId, + false, + complete, + initialRequestN, + metadataRetainedSlice, + dataRetainedSlice); + break; + default: + throw new IllegalArgumentException("Unsupported frame type " + frameType); + } + + connection.sendFrame(streamId, requestFrame); + } + } + + static void sendTerminalFrame( + int streamId, + FrameType frameType, + DuplexConnection connection, + ByteBufAllocator allocator, + boolean requester, + boolean onFollowingFrame, + Throwable t) { + + if (onFollowingFrame) { + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode fragmented " + + frameType + + " frame. Cause: " + + t.getMessage())); + connection.sendFrame(streamId, errorFrame); + } + } else { + switch (frameType) { + case NEXT_COMPLETE: + case NEXT: + case PAYLOAD: + if (requester) { + final ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, streamId); + connection.sendFrame(streamId, cancelFrame); + } else { + final ByteBuf errorFrame = + ErrorFrameCodec.encode( + allocator, + streamId, + new CanceledException( + "Failed to encode " + frameType + " frame. Cause: " + t.getMessage())); + connection.sendFrame(streamId, errorFrame); + } + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java new file mode 100644 index 000000000..5aae22e89 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/ServerSetup.java @@ -0,0 +1,165 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.keepalive.KeepAliveHandler.*; + +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.exceptions.UnsupportedSetupException; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.keepalive.KeepAliveHandler; +import io.rsocket.resume.*; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.function.BiFunction; +import java.util.function.Function; +import reactor.core.publisher.Mono; +import reactor.util.function.Tuple2; + +abstract class ServerSetup { + + final Duration timeout; + + protected ServerSetup(Duration timeout) { + this.timeout = timeout; + } + + Mono> init(DuplexConnection connection) { + return Mono.>create( + sink -> sink.onRequest(__ -> new SetupHandlingDuplexConnection(connection, sink))) + .timeout(this.timeout) + .or(connection.onClose().then(Mono.error(ClosedChannelException::new))); + } + + abstract Mono acceptRSocketSetup( + ByteBuf frame, + DuplexConnection clientServerConnection, + BiFunction> then); + + abstract Mono acceptRSocketResume(ByteBuf frame, DuplexConnection connection); + + void dispose() {} + + void sendError(DuplexConnection duplexConnection, RSocketErrorException exception) { + duplexConnection.sendErrorAndClose(exception); + duplexConnection.receive().subscribe(); + } + + static class DefaultServerSetup extends ServerSetup { + + DefaultServerSetup(Duration timeout) { + super(timeout); + } + + @Override + public Mono acceptRSocketSetup( + ByteBuf frame, + DuplexConnection duplexConnection, + BiFunction> then) { + + if (SetupFrameCodec.resumeEnabled(frame)) { + sendError(duplexConnection, new UnsupportedSetupException("resume not supported")); + return duplexConnection.onClose(); + } else { + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); + } + } + + @Override + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { + sendError(duplexConnection, new RejectedResumeException("resume not supported")); + return duplexConnection.onClose(); + } + } + + static class ResumableServerSetup extends ServerSetup { + private final SessionManager sessionManager; + private final Duration resumeSessionDuration; + private final Duration resumeStreamTimeout; + private final Function resumeStoreFactory; + private final boolean cleanupStoreOnKeepAlive; + + ResumableServerSetup( + Duration timeout, + SessionManager sessionManager, + Duration resumeSessionDuration, + Duration resumeStreamTimeout, + Function resumeStoreFactory, + boolean cleanupStoreOnKeepAlive) { + super(timeout); + this.sessionManager = sessionManager; + this.resumeSessionDuration = resumeSessionDuration; + this.resumeStreamTimeout = resumeStreamTimeout; + this.resumeStoreFactory = resumeStoreFactory; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + } + + @Override + public Mono acceptRSocketSetup( + ByteBuf frame, + DuplexConnection duplexConnection, + BiFunction> then) { + + if (SetupFrameCodec.resumeEnabled(frame)) { + ByteBuf resumeToken = SetupFrameCodec.resumeToken(frame); + + final ResumableFramesStore resumableFramesStore = resumeStoreFactory.apply(resumeToken); + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "server", resumeToken, duplexConnection, resumableFramesStore); + final ServerRSocketSession serverRSocketSession = + new ServerRSocketSession( + resumeToken, + resumableDuplexConnection, + duplexConnection, + resumableFramesStore, + resumeSessionDuration, + cleanupStoreOnKeepAlive); + + sessionManager.save(serverRSocketSession, resumeToken); + + return then.apply( + new ResumableKeepAliveHandler( + resumableDuplexConnection, serverRSocketSession, serverRSocketSession), + resumableDuplexConnection); + } else { + return then.apply(new DefaultKeepAliveHandler(), duplexConnection); + } + } + + @Override + public Mono acceptRSocketResume(ByteBuf frame, DuplexConnection duplexConnection) { + ServerRSocketSession session = sessionManager.get(ResumeFrameCodec.token(frame)); + if (session != null) { + session.resumeWith(frame, duplexConnection); + return duplexConnection.onClose(); + } else { + sendError(duplexConnection, new RejectedResumeException("unknown resume token")); + return duplexConnection.onClose(); + } + } + + @Override + public void dispose() { + sessionManager.dispose(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java new file mode 100644 index 000000000..3beedf97f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SetupHandlingDuplexConnection.java @@ -0,0 +1,176 @@ +package io.rsocket.core; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoSink; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; + +class SetupHandlingDuplexConnection extends Flux + implements DuplexConnection, CoreSubscriber, Subscription { + + final DuplexConnection source; + final MonoSink> sink; + + Subscription s; + boolean firstFrameReceived = false; + + CoreSubscriber actual; + + boolean done; + Throwable t; + + SetupHandlingDuplexConnection( + DuplexConnection source, MonoSink> sink) { + this.source = source; + this.sink = sink; + + source.receive().subscribe(this); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + if (done) { + final Throwable t = this.t; + if (t == null) { + Operators.complete(actual); + } else { + Operators.error(actual, t); + } + return; + } + + this.actual = actual; + actual.onSubscribe(this); + } + + @Override + public void request(long n) { + if (n != Long.MAX_VALUE) { + actual.onError(new IllegalArgumentException("Only unbounded request is allowed")); + return; + } + + s.request(Long.MAX_VALUE); + } + + @Override + public void cancel() { + source.dispose(); + s.cancel(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + s.request(1); + } + } + + @Override + public void onNext(ByteBuf frame) { + if (!firstFrameReceived) { + firstFrameReceived = true; + sink.success(Tuples.of(frame, this)); + return; + } + + actual.onNext(frame); + } + + @Override + public void onError(Throwable t) { + if (done) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.done = true; + this.t = t; + + if (!firstFrameReceived) { + sink.error(t); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onError(t); + } + } + + @Override + public void onComplete() { + if (done) { + return; + } + + this.done = true; + + if (!firstFrameReceived) { + sink.error(new ClosedChannelException()); + return; + } + + final CoreSubscriber actual = this.actual; + if (actual != null) { + actual.onComplete(); + } + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + source.sendErrorAndClose(e); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public String toString() { + return "SetupHandlingDuplexConnection{" + "source=" + source + ", done=" + done + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java new file mode 100644 index 000000000..3035696b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/SlowFireAndForgetRequesterMono.java @@ -0,0 +1,255 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.PayloadValidationUtils.isValid; +import static io.rsocket.core.SendUtils.sendReleasingPayload; +import static io.rsocket.core.StateUtils.isReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.isSubscribedOrTerminated; +import static io.rsocket.core.StateUtils.isTerminated; +import static io.rsocket.core.StateUtils.lazyTerminate; +import static io.rsocket.core.StateUtils.markReadyToSendFirstFrame; +import static io.rsocket.core.StateUtils.markSubscribed; +import static io.rsocket.core.StateUtils.markTerminated; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.annotation.Nullable; + +final class SlowFireAndForgetRequesterMono extends Mono + implements LeasePermitHandler, Subscription, Scannable { + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(SlowFireAndForgetRequesterMono.class, "state"); + + final Payload payload; + + final ByteBufAllocator allocator; + final int mtu; + final int maxFrameLength; + final RequesterResponderSupport requesterResponderSupport; + final DuplexConnection connection; + + @Nullable final RequesterLeaseTracker requesterLeaseTracker; + @Nullable final RequestInterceptor requestInterceptor; + + CoreSubscriber actual; + + SlowFireAndForgetRequesterMono( + Payload payload, RequesterResponderSupport requesterResponderSupport) { + this.allocator = requesterResponderSupport.getAllocator(); + this.payload = payload; + this.mtu = requesterResponderSupport.getMtu(); + this.maxFrameLength = requesterResponderSupport.getMaxFrameLength(); + this.requesterResponderSupport = requesterResponderSupport; + this.connection = requesterResponderSupport.getDuplexConnection(); + this.requestInterceptor = requesterResponderSupport.getRequestInterceptor(); + this.requesterLeaseTracker = requesterResponderSupport.getRequesterLeaseTracker(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + final RequesterLeaseTracker requesterLeaseTracker = this.requesterLeaseTracker; + final boolean leaseEnabled = requesterLeaseTracker != null; + long previousState = markSubscribed(STATE, this, !leaseEnabled); + if (isSubscribedOrTerminated(previousState)) { + final IllegalStateException e = + new IllegalStateException("FireAndForgetMono allows only a single Subscriber"); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + final Payload p = this.payload; + int mtu = this.mtu; + try { + if (!isValid(mtu, this.maxFrameLength, p, false)) { + lazyTerminate(STATE, this); + + final IllegalArgumentException e = + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, this.maxFrameLength)); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + Operators.error(actual, e); + return; + } + } catch (IllegalReferenceCountException e) { + lazyTerminate(STATE, this); + + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(e, FrameType.REQUEST_FNF, null); + } + + Operators.error(actual, e); + return; + } + + this.actual = actual; + actual.onSubscribe(this); + + if (leaseEnabled) { + requesterLeaseTracker.issue(this); + return; + } + + sendFirstFrame(p); + } + + @Override + public boolean handlePermit() { + final long previousState = markReadyToSendFirstFrame(STATE, this); + + if (isTerminated(previousState)) { + return false; + } + + sendFirstFrame(this.payload); + return true; + } + + void sendFirstFrame(Payload p) { + final CoreSubscriber actual = this.actual; + final int streamId; + try { + streamId = this.requesterResponderSupport.getNextStreamId(); + } catch (Throwable t) { + lazyTerminate(STATE, this); + + final Throwable ut = Exceptions.unwrap(t); + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(ut, FrameType.REQUEST_FNF, p.metadata()); + } + + p.release(); + + actual.onError(ut); + return; + } + + final RequestInterceptor interceptor = this.requestInterceptor; + if (interceptor != null) { + interceptor.onStart(streamId, FrameType.REQUEST_FNF, p.metadata()); + } + + try { + if (isTerminated(this.state)) { + p.release(); + + if (interceptor != null) { + interceptor.onCancel(streamId, FrameType.REQUEST_FNF); + } + + return; + } + + sendReleasingPayload( + streamId, FrameType.REQUEST_FNF, mtu, p, this.connection, this.allocator, true); + } catch (Throwable e) { + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, e); + } + + actual.onError(e); + return; + } + + lazyTerminate(STATE, this); + + if (interceptor != null) { + interceptor.onTerminate(streamId, FrameType.REQUEST_FNF, null); + } + + actual.onComplete(); + } + + @Override + public void request(long n) { + // no ops + } + + @Override + public void cancel() { + final long previousState = markTerminated(STATE, this); + + if (isTerminated(previousState)) { + return; + } + + if (!isReadyToSendFirstFrame(previousState)) { + this.payload.release(); + } + } + + @Override + public final void handlePermitError(Throwable cause) { + long previousState = markTerminated(STATE, this); + if (isTerminated(previousState)) { + Operators.onErrorDropped(cause, this.actual.currentContext()); + return; + } + + final Payload p = this.payload; + final RequestInterceptor requestInterceptor = this.requestInterceptor; + if (requestInterceptor != null) { + requestInterceptor.onReject(cause, FrameType.REQUEST_RESPONSE, p.metadata()); + } + + p.release(); + + this.actual.onError(cause); + } + + @Override + public Object scanUnsafe(Attr key) { + return null; // no particular key to be represented, still useful in hooks + } + + @Override + @NonNull + public String stepName() { + return "source(FireAndForgetMono)"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java new file mode 100644 index 000000000..2b6a0e09a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/StateUtils.java @@ -0,0 +1,493 @@ +package io.rsocket.core; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +final class StateUtils { + + /** Volatile Long Field bit mask that allows extract flags stored in the field */ + static final long FLAGS_MASK = + 0b111111111111111111111111111111111_0000000000000000000000000000000L; + /** Volatile Long Field bit mask that allows extract int RequestN stored in the field */ + static final long REQUEST_MASK = + 0b000000000000000000000000000000000_1111111111111111111111111111111L; + /** Bit Flag that indicates Requester Producer has been subscribed once */ + static final long SUBSCRIBED_FLAG = + 0b000000000000000000000000000000001_0000000000000000000000000000000L; + /** Bit Flag that indicates that the first payload in RequestChannel scenario is received */ + static final long FIRST_PAYLOAD_RECEIVED_FLAG = + 0b000000000000000000000000000000010_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that the logical stream is ready to send the first initial frame + * (applicable for requester only) + */ + static final long READY_TO_SEND_FIRST_FRAME_FLAG = + 0b000000000000000000000000000000100_0000000000000000000000000000000L; + /** + * Bit Flag that indicates that sent first initial frame was sent (in case of requester) or + * consumed (if responder) + */ + static final long FIRST_FRAME_SENT_FLAG = + 0b000000000000000000000000000001000_0000000000000000000000000000000L; + /** Bit Flag that indicates that there is a frame being reassembled */ + static final long REASSEMBLING_FLAG = + 0b000000000000000000000000000010000_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the inbound is terminated + */ + static final long INBOUND_TERMINATED_FLAG = + 0b000000000000000000000000000100000_0000000000000000000000000000000L; + /** + * Bit Flag that indicates requestChannel stream is half terminated. In this case flag indicates + * that the outbound is terminated + */ + static final long OUTBOUND_TERMINATED_FLAG = + 0b000000000000000000000000001000000_0000000000000000000000000000000L; + /** Initial state for any request operator */ + static final long UNSUBSCRIBED_STATE = + 0b000000000000000000000000000000000_0000000000000000000000000000000L; + /** State that indicates request operator was terminated */ + static final long TERMINATED_STATE = + 0b100000000000000000000000000000000_0000000000000000000000000000000L; + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed(AtomicLongFieldUpdater updater, T instance) { + return markSubscribed(updater, instance, false); + } + + /** + * Adds (if possible) to the given state the {@link #SUBSCRIBED_FLAG} flag which indicates that + * the given stream has already been subscribed once + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been subscribed once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param markPrepared indicates whether the given instance should be marked as prepared + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markSubscribed( + AtomicLongFieldUpdater updater, T instance, boolean markPrepared) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG) { + return state; + } + + if (updater.compareAndSet( + instance, + state, + state | SUBSCRIBED_FLAG | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0))) { + return state; + } + } + } + + /** + * Indicates that the given stream has already been subscribed once + * + * @param state to check whether stream is subscribed + * @return true if the {@link #SUBSCRIBED_FLAG} flag is set + */ + static boolean isSubscribed(long state) { + return (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_FRAME_SENT_FLAG} flag which indicates + * that the first frame has already set and logical stream has already been established. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been established once + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstFrameSent(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_FRAME_SENT_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the first frame which established logical stream has already been sent + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_FRAME_SENT_FLAG} flag is set + */ + static boolean isFirstFrameSent(long state) { + return (state & FIRST_FRAME_SENT_FLAG) == FIRST_FRAME_SENT_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReadyToSendFirstFrame(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | READY_TO_SEND_FIRST_FRAME_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is prepared for initial frame sending + * @return true if the {@link #READY_TO_SEND_FIRST_FRAME_FLAG} flag is set + */ + static boolean isReadyToSendFirstFrame(long state) { + return (state & READY_TO_SEND_FIRST_FRAME_FLAG) == READY_TO_SEND_FIRST_FRAME_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag which + * indicates that the logical stream is ready for initial frame sending. + * + *

Note, the flag will not be added if the stream has already been terminated or if the stream + * has already been marked as prepared + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markFirstPayloadReceived(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG) { + return state; + } + + if (updater.compareAndSet(instance, state, state | FIRST_PAYLOAD_RECEIVED_FLAG)) { + return state; + } + } + } + + /** + * Indicates that the logical stream is ready for initial frame sending + * + * @param state to check whether stream is established + * @return true if the {@link #FIRST_PAYLOAD_RECEIVED_FLAG} flag is set + */ + static boolean isFirstPayloadReceived(long state) { + return (state & FIRST_PAYLOAD_RECEIVED_FLAG) == FIRST_PAYLOAD_RECEIVED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #REASSEMBLING_FLAG} flag which indicates that + * there is a payload reassembling in progress. + * + *

Note, the flag will not be added if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembling(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state | REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Removes (if possible) from the given state the {@link #REASSEMBLING_FLAG} flag which indicates + * that a payload reassembly process is completed. + * + *

Note, the flag will not be removed if the stream has already been terminated + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markReassembled(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if (updater.compareAndSet(instance, state, state & ~REASSEMBLING_FLAG)) { + return state; + } + } + } + + /** + * Indicates that a payload reassembly process is completed. + * + * @param state to check whether there is reassembly in progress + * @return true if the {@link #REASSEMBLING_FLAG} flag is set + */ + static boolean isReassembling(long state) { + return (state & REASSEMBLING_FLAG) == REASSEMBLING_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #INBOUND_TERMINATED_FLAG} flag which indicates + * that an inbound channel of a bidirectional stream is terminated. + * + *

Note, this action will have no effect if the stream has already been terminated or if + * the {@link #INBOUND_TERMINATED_FLAG} flag has already been set.
+ * Note, if the outbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markInboundTerminated(AtomicLongFieldUpdater updater, T instance) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + return state; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | INBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the inbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #INBOUND_TERMINATED_FLAG} set + * @return true if the {@link #INBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isInboundTerminated(long state) { + return (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG; + } + + /** + * Adds (if possible) to the given state the {@link #OUTBOUND_TERMINATED_FLAG} flag which + * indicates that an outbound channel of a bidirectional stream is terminated. + * + *

Note, this action will have no effect if the stream has already been terminated or if + * the {@link #OUTBOUND_TERMINATED_FLAG} flag has already been set.
+ * Note, if the {@code checkEstablishment} parameter is {@code true} and the logical stream + * is not established, then the result state will be {@link #TERMINATED_STATE}
+ * Note, if the inbound stream has already been terminated, then the result state will be + * {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param checkEstablishment indicates whether {@link #FIRST_FRAME_SENT_FLAG} should be checked to + * make final decision + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markOutboundTerminated( + AtomicLongFieldUpdater updater, T instance, boolean checkEstablishment) { + for (; ; ) { + long state = updater.get(instance); + + if (state == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + if ((state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG) { + return state; + } + + if ((checkEstablishment && !isFirstFrameSent(state)) + || (state & INBOUND_TERMINATED_FLAG) == INBOUND_TERMINATED_FLAG) { + if (updater.compareAndSet(instance, state, TERMINATED_STATE)) { + return state; + } + } else { + if (updater.compareAndSet(instance, state, state | OUTBOUND_TERMINATED_FLAG)) { + return state; + } + } + } + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isOutboundTerminated(long state) { + return (state & OUTBOUND_TERMINATED_FLAG) == OUTBOUND_TERMINATED_FLAG; + } + + /** + * Makes current state a {@link #TERMINATED_STATE} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + * @return return previous state before setting the new one + */ + static long markTerminated(AtomicLongFieldUpdater updater, T instance) { + return updater.getAndSet(instance, TERMINATED_STATE); + } + + /** + * Makes current state a {@link #TERMINATED_STATE} using {@link + * AtomicLongFieldUpdater#lazySet(Object, long)} + * + * @param updater of the volatile state field + * @param instance instance holder of the volatile state + * @param generic type of the instance + */ + static void lazyTerminate(AtomicLongFieldUpdater updater, T instance) { + updater.lazySet(instance, TERMINATED_STATE); + } + + /** + * Indicates that a the outbound channel of a bidirectional stream is terminated. + * + * @param state to check whether it has {@link #OUTBOUND_TERMINATED_FLAG} set + * @return true if the {@link #OUTBOUND_TERMINATED_FLAG} flag is set + */ + static boolean isTerminated(long state) { + return state == TERMINATED_STATE; + } + + /** + * Shortcut for {@link #isSubscribed} {@code ||} {@link #isTerminated} methods + * + * @param state to check flags on + * @return true if state is terminated or has flag subscribed + */ + static boolean isSubscribedOrTerminated(long state) { + return state == TERMINATED_STATE || (state & SUBSCRIBED_FLAG) == SUBSCRIBED_FLAG; + } + + static long addRequestN(AtomicLongFieldUpdater updater, T instance, long toAdd) { + return addRequestN(updater, instance, toAdd, false); + } + + static long addRequestN( + AtomicLongFieldUpdater updater, T instance, long toAdd, boolean markPrepared) { + long currentState, flags, requestN, nextRequestN; + for (; ; ) { + currentState = updater.get(instance); + + if (currentState == TERMINATED_STATE) { + return TERMINATED_STATE; + } + + requestN = currentState & REQUEST_MASK; + if (requestN == REQUEST_MASK) { + return currentState; + } + + flags = (currentState & FLAGS_MASK) | (markPrepared ? READY_TO_SEND_FIRST_FRAME_FLAG : 0); + nextRequestN = addRequestN(requestN, toAdd); + + if (updater.compareAndSet(instance, currentState, nextRequestN | flags)) { + return currentState; + } + } + } + + static long addRequestN(long a, long b) { + long res = a + b; + if (res < 0 || res > REQUEST_MASK) { + return REQUEST_MASK; + } + return res; + } + + static boolean hasRequested(long state) { + return (state & REQUEST_MASK) > 0; + } + + static long extractRequestN(long state) { + long requestN = state & REQUEST_MASK; + + if (requestN == REQUEST_MASK) { + return REQUEST_MASK; + } + + return requestN; + } + + static boolean isMaxAllowedRequestN(long n) { + return n >= REQUEST_MASK; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java new file mode 100644 index 000000000..15d39c993 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/StreamIdSupplier.java @@ -0,0 +1,58 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import io.netty.util.collection.IntObjectMap; + +/** This API is not thread-safe and must be strictly used in serialized fashion */ +final class StreamIdSupplier { + private static final int MASK = 0x7FFFFFFF; + + private long streamId; + + // Visible for testing + StreamIdSupplier(int streamId) { + this.streamId = streamId; + } + + static StreamIdSupplier clientSupplier() { + return new StreamIdSupplier(-1); + } + + static StreamIdSupplier serverSupplier() { + return new StreamIdSupplier(0); + } + + /** + * This methods provides new stream id and ensures there is no intersections with already running + * streams. This methods is not thread-safe. + * + * @param streamIds currently running streams store + * @return next stream id + */ + int nextStreamId(IntObjectMap streamIds) { + int streamId; + do { + this.streamId += 2; + streamId = (int) (this.streamId & MASK); + } while (streamId == 0 || streamIds.containsKey(streamId)); + return streamId; + } + + boolean isBeforeOrCurrent(int streamId) { + return this.streamId >= streamId && streamId > 0; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/core/package-info.java b/rsocket-core/src/main/java/io/rsocket/core/package-info.java new file mode 100644 index 000000000..29db3f205 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/core/package-info.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains {@link io.rsocket.core.RSocketConnector RSocketConnector} and {@link + * io.rsocket.core.RSocketServer RSocketServer}, the main classes for connecting to or starting an + * RSocket server. + * + *

This package also contains a package private classes that implement support for the main + * RSocket interactions. + */ +@NonNullApi +package io.rsocket.core; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java new file mode 100644 index 000000000..40cb15dd6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ApplicationErrorException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * Application layer logic generating a Reactive Streams {@code onError} event. + * + * @see Error + * Codes + */ +public final class ApplicationErrorException extends RSocketErrorException { + + private static final long serialVersionUID = 7873267740343446585L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public ApplicationErrorException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public ApplicationErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java new file mode 100644 index 000000000..144ef94c6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CanceledException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The Responder canceled the request but may have started processing it (similar to REJECTED but + * doesn't guarantee lack of side-effects). + * + * @see Error + * Codes + */ +public final class CanceledException extends RSocketErrorException { + + private static final long serialVersionUID = 5074789326089722770L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public CanceledException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public CanceledException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CANCELED, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java new file mode 100644 index 000000000..1e0167bdd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionCloseException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The connection is being terminated. Sender or Receiver of this frame MUST wait for outstanding + * streams to terminate before closing the connection. New requests MAY not be accepted. + * + * @see Error + * Codes + */ +public final class ConnectionCloseException extends RSocketErrorException { + + private static final long serialVersionUID = -2214953527482377471L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public ConnectionCloseException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public ConnectionCloseException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_CLOSE, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java new file mode 100644 index 000000000..5cf7cff66 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/ConnectionErrorException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The connection is being terminated. Sender or Receiver of this frame MAY close the connection + * immediately without waiting for outstanding streams to terminate. + * + * @see Error + * Codes + */ +public final class ConnectionErrorException extends RSocketErrorException implements Retryable { + + private static final long serialVersionUID = 512325887785119744L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public ConnectionErrorException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public ConnectionErrorException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.CONNECTION_ERROR, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java new file mode 100644 index 000000000..a72c0ba3b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/CustomRSocketException.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +public class CustomRSocketException extends RSocketErrorException { + private static final long serialVersionUID = 7873267740343446585L; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message) { + this(errorCode, message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code. Should be in range [0x00000301-0xFFFFFFFE] + * @param message the message + * @param cause the cause of this exception + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public CustomRSocketException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + if (errorCode > ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE + && errorCode < ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE) { + throw new IllegalArgumentException( + "Allowed errorCode value should be in range [0x00000301-0xFFFFFFFE]", this); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java new file mode 100644 index 000000000..5c6eee614 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Exceptions.java @@ -0,0 +1,95 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; + +import io.netty.buffer.ByteBuf; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import java.util.Objects; + +/** Utility class that generates an exception from a frame. */ +public final class Exceptions { + + private Exceptions() {} + + /** + * Create a {@link RSocketErrorException} from a Frame that matches the error code it contains. + * + * @param frame the frame to retrieve the error code and message from + * @return a {@link RSocketErrorException} that matches the error code in the Frame + * @throws NullPointerException if {@code frame} is {@code null} + */ + public static RuntimeException from(int streamId, ByteBuf frame) { + Objects.requireNonNull(frame, "frame must not be null"); + + int errorCode = ErrorFrameCodec.errorCode(frame); + String message = ErrorFrameCodec.dataUtf8(frame); + + if (streamId == 0) { + switch (errorCode) { + case INVALID_SETUP: + return new InvalidSetupException(message); + case UNSUPPORTED_SETUP: + return new UnsupportedSetupException(message); + case REJECTED_SETUP: + return new RejectedSetupException(message); + case REJECTED_RESUME: + return new RejectedResumeException(message); + case CONNECTION_ERROR: + return new ConnectionErrorException(message); + case CONNECTION_CLOSE: + return new ConnectionCloseException(message); + default: + return new IllegalArgumentException( + String.format("Invalid Error frame in Stream ID 0: 0x%08X '%s'", errorCode, message)); + } + } else { + switch (errorCode) { + case APPLICATION_ERROR: + return new ApplicationErrorException(message); + case REJECTED: + return new RejectedException(message); + case CANCELED: + return new CanceledException(message); + case INVALID: + return new InvalidException(message); + default: + if (errorCode >= MIN_USER_ALLOWED_ERROR_CODE + || errorCode <= MAX_USER_ALLOWED_ERROR_CODE) { + return new CustomRSocketException(errorCode, message); + } + return new IllegalArgumentException( + String.format( + "Invalid Error frame in Stream ID %d: 0x%08X '%s'", + streamId, errorCode, message)); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java new file mode 100644 index 000000000..c556423b9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The request is invalid. + * + * @see Error + * Codes + */ +public final class InvalidException extends RSocketErrorException { + + private static final long serialVersionUID = 8279420324864928243L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public InvalidException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public InvalidException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java new file mode 100644 index 000000000..b0889c5a6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/InvalidSetupException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The Setup frame is invalid for the server (it could be that the client is too recent for the old + * server). + * + * @see Error + * Codes + */ +public final class InvalidSetupException extends SetupException { + + private static final long serialVersionUID = -6816210006610385251L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public InvalidSetupException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public InvalidSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.INVALID_SETUP, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java new file mode 100644 index 000000000..8bc946e3d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedException.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * Despite being a valid request, the Responder decided to reject it. The Responder guarantees that + * it didn't process the request. The reason for the rejection is explained in the Error Data + * section. + * + * @see Error + * Codes + */ +public class RejectedException extends RSocketErrorException implements Retryable { + + private static final long serialVersionUID = 3926231092835143715L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public RejectedException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public RejectedException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java new file mode 100644 index 000000000..44cc55710 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedResumeException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The server rejected the resume, it can specify the reason in the payload. + * + * @see Error + * Codes + */ +public final class RejectedResumeException extends RSocketErrorException { + + private static final long serialVersionUID = -873684362478544811L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public RejectedResumeException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public RejectedResumeException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_RESUME, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java new file mode 100644 index 000000000..c09a27e32 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/RejectedSetupException.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * The server rejected the setup, it can specify the reason in the payload. + * + * @see Error + * Codes + */ +public final class RejectedSetupException extends SetupException implements Retryable { + + private static final long serialVersionUID = 8757401529926371738L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public RejectedSetupException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public RejectedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.REJECTED_SETUP, message, cause); + } +} diff --git a/src/main/java/io/reactivesocket/ConnectionSetupHandler.java b/rsocket-core/src/main/java/io/rsocket/exceptions/Retryable.java similarity index 59% rename from src/main/java/io/reactivesocket/ConnectionSetupHandler.java rename to rsocket-core/src/main/java/io/rsocket/exceptions/Retryable.java index ced202cc9..e61fe4f97 100644 --- a/src/main/java/io/reactivesocket/ConnectionSetupHandler.java +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/Retryable.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket; -import io.reactivesocket.exceptions.SetupException; +package io.rsocket.exceptions; -public interface ConnectionSetupHandler { - RequestHandler apply(ConnectionSetupPayload setupPayload, ReactiveSocket reactiveSocket) throws SetupException; // yeah, a checked exception -} \ No newline at end of file +/** + * Indicates that an exception is retryable. This interface is a marker and the strategy for + * retrying and operation that causes a {@link Retryable} to be thrown is not specified. + */ +public interface Retryable {} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java new file mode 100644 index 000000000..76dc39a59 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/SetupException.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import reactor.util.annotation.Nullable; + +/** The root of the setup exception hierarchy. */ +public abstract class SetupException extends RSocketErrorException { + + private static final long serialVersionUID = -2928269501877732756L; + + /** + * Constructs a new exception with the specified error code, message and cause. + * + * @param errorCode the RSocket protocol code + * @param message the message + * @param cause the cause of this exception + */ + public SetupException(int errorCode, String message, @Nullable Throwable cause) { + super(errorCode, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java new file mode 100644 index 000000000..7429ccd98 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/UnsupportedSetupException.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import io.rsocket.frame.ErrorFrameCodec; +import reactor.util.annotation.Nullable; + +/** + * Some (or all) of the parameters specified by the client are unsupported by the server. + * + * @see Error + * Codes + */ +public final class UnsupportedSetupException extends SetupException { + + private static final long serialVersionUID = -1892507835635323415L; + + /** + * Constructs a new exception with the specified message. + * + * @param message the message + */ + public UnsupportedSetupException(String message) { + this(message, null); + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param message the message + * @param cause the cause of this exception + */ + public UnsupportedSetupException(String message, @Nullable Throwable cause) { + super(ErrorFrameCodec.UNSUPPORTED_SETUP, message, cause); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java new file mode 100644 index 000000000..969aedded --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/exceptions/package-info.java @@ -0,0 +1,26 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * A hierarchy of exceptions that represent RSocket protocol error codes. + * + * @see Error + * Codes + */ +@NonNullApi +package io.rsocket.exceptions; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java new file mode 100644 index 000000000..d0d929f0f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/CancelFrameCodec.java @@ -0,0 +1,12 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class CancelFrameCodec { + private CancelFrameCodec() {} + + public static ByteBuf encode(final ByteBufAllocator allocator, final int streamId) { + return FrameHeaderCodec.encode(allocator, streamId, FrameType.CANCEL, 0); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java new file mode 100644 index 000000000..dcacb57dc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ErrorFrameCodec.java @@ -0,0 +1,66 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.RSocketErrorException; +import java.nio.charset.StandardCharsets; + +public class ErrorFrameCodec { + + // defined zero stream id error codes + public static final int INVALID_SETUP = 0x00000001; + public static final int UNSUPPORTED_SETUP = 0x00000002; + public static final int REJECTED_SETUP = 0x00000003; + public static final int REJECTED_RESUME = 0x00000004; + public static final int CONNECTION_ERROR = 0x00000101; + public static final int CONNECTION_CLOSE = 0x00000102; + // defined non-zero stream id error codes + public static final int APPLICATION_ERROR = 0x00000201; + public static final int REJECTED = 0x00000202; + public static final int CANCELED = 0x00000203; + public static final int INVALID = 0x00000204; + // defined user-allowed error codes range + public static final int MIN_USER_ALLOWED_ERROR_CODE = 0x00000301; + public static final int MAX_USER_ALLOWED_ERROR_CODE = 0xFFFFFFFE; + + public static ByteBuf encode( + ByteBufAllocator allocator, int streamId, Throwable t, ByteBuf data) { + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.ERROR, 0); + + int errorCode = + t instanceof RSocketErrorException + ? ((RSocketErrorException) t).errorCode() + : APPLICATION_ERROR; + + header.writeInt(errorCode); + + return allocator.compositeBuffer(2).addComponents(true, header, data); + } + + public static ByteBuf encode(ByteBufAllocator allocator, int streamId, Throwable t) { + String message = t.getMessage() == null ? "" : t.getMessage(); + ByteBuf data = ByteBufUtil.writeUtf8(allocator, message); + return encode(allocator, streamId, t, data); + } + + public static int errorCode(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i; + } + + public static ByteBuf data(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf slice = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return slice; + } + + public static String dataUtf8(ByteBuf byteBuf) { + return data(byteBuf).toString(StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java new file mode 100644 index 000000000..418926596 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ExtensionFrameCodec.java @@ -0,0 +1,67 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class ExtensionFrameCodec { + private ExtensionFrameCodec() {} + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + int extendedType, + @Nullable ByteBuf metadata, + ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = FrameHeaderCodec.FLAGS_I; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.EXT, flags); + header.writeInt(extendedType); + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + public static int extendedType(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i; + } + + public static ByteBuf data(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.EXT, byteBuf); + + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + // Extended type + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java new file mode 100644 index 000000000..de228b271 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FragmentationCodec.java @@ -0,0 +1,19 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +/** FragmentationFlyweight is used to re-assemble frames */ +public class FragmentationCodec { + public static ByteBuf encode(final ByteBufAllocator allocator, ByteBuf header, ByteBuf data) { + return encode(allocator, header, null, data); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, ByteBuf header, @Nullable ByteBuf metadata, ByteBuf data) { + + final boolean hasMetadata = metadata != null; + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java new file mode 100644 index 000000000..ea011e503 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameBodyCodec.java @@ -0,0 +1,103 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import reactor.util.annotation.Nullable; + +class FrameBodyCodec { + public static final int FRAME_LENGTH_MASK = 0xFFFFFF; + + private FrameBodyCodec() {} + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + byte b = byteBuf.readByte(); + int length = (b & 0xFF) << 16; + byte b1 = byteBuf.readByte(); + length |= (b1 & 0xFF) << 8; + byte b2 = byteBuf.readByte(); + length |= b2 & 0xFF; + return length; + } + + static ByteBuf encode( + ByteBufAllocator allocator, + final ByteBuf header, + @Nullable ByteBuf metadata, + boolean hasMetadata, + @Nullable ByteBuf data) { + + final boolean addData; + if (data != null) { + if (data.isReadable()) { + addData = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + data.release(); + addData = false; + } + } else { + addData = false; + } + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (hasMetadata) { + int length = metadata.readableBytes(); + encodeLength(header, length); + } + + if (addMetadata && addData) { + return allocator.compositeBuffer(3).addComponents(true, header, metadata, data); + } else if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else if (addData) { + return allocator.compositeBuffer(2).addComponents(true, header, data); + } else { + return header; + } + } + + static ByteBuf metadataWithoutMarking(ByteBuf byteBuf) { + int length = decodeLength(byteBuf); + return byteBuf.readSlice(length); + } + + static ByteBuf dataWithoutMarking(ByteBuf byteBuf, boolean hasMetadata) { + if (hasMetadata) { + /*moves reader index*/ + int length = decodeLength(byteBuf); + byteBuf.skipBytes(length); + } + if (byteBuf.readableBytes() > 0) { + return byteBuf.readSlice(byteBuf.readableBytes()); + } else { + return Unpooled.EMPTY_BUFFER; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java new file mode 100644 index 000000000..fc146c935 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameHeaderCodec.java @@ -0,0 +1,140 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.reactivestreams.Subscriber; + +/** + * Per connection frame flyweight. + * + *

Not the latest frame layout, but close. Does not include - fragmentation / reassembly - encode + * should remove Type param and have it as part of method name (1 encode per type?) + * + *

Not thread-safe. Assumed to be used single-threaded + */ +public final class FrameHeaderCodec { + /** (I)gnore flag: a value of 0 indicates the protocol can't ignore this frame */ + public static final int FLAGS_I = 0b10_0000_0000; + /** (M)etadata flag: a value of 1 indicates the frame contains metadata */ + public static final int FLAGS_M = 0b01_0000_0000; + /** + * (F)ollows: More fragments follow this fragment (in case of fragmented REQUEST_x or PAYLOAD + * frames) + */ + public static final int FLAGS_F = 0b00_1000_0000; + /** (C)omplete: bit to indicate stream completion ({@link Subscriber#onComplete()}) */ + public static final int FLAGS_C = 0b00_0100_0000; + /** (N)ext: bit to indicate payload or metadata present ({@link Subscriber#onNext(Object)}) */ + public static final int FLAGS_N = 0b00_0010_0000; + + public static final String DISABLE_FRAME_TYPE_CHECK = "io.rsocket.frames.disableFrameTypeCheck"; + private static final int FRAME_FLAGS_MASK = 0b0000_0011_1111_1111; + private static final int FRAME_TYPE_BITS = 6; + private static final int FRAME_TYPE_SHIFT = 16 - FRAME_TYPE_BITS; + private static final int HEADER_SIZE = Integer.BYTES + Short.BYTES; + private static boolean disableFrameTypeCheck; + + static { + disableFrameTypeCheck = Boolean.getBoolean(DISABLE_FRAME_TYPE_CHECK); + } + + private FrameHeaderCodec() {} + + static ByteBuf encodeStreamZero( + final ByteBufAllocator allocator, final FrameType frameType, int flags) { + return encode(allocator, 0, frameType, flags); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, final int streamId, final FrameType frameType, int flags) { + if (!frameType.canHaveMetadata() && ((flags & FLAGS_M) == FLAGS_M)) { + throw new IllegalStateException("bad value for metadata flag"); + } + + short typeAndFlags = (short) (frameType.getEncodedType() << FRAME_TYPE_SHIFT | (short) flags); + + return allocator.buffer().writeInt(streamId).writeShort(typeAndFlags); + } + + public static boolean hasFollows(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_F) == FLAGS_F; + } + + public static boolean hasComplete(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_C) == FLAGS_C; + } + + public static int streamId(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int streamId = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return streamId; + } + + public static int flags(final ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(Integer.BYTES); + short typeAndFlags = byteBuf.readShort(); + byteBuf.resetReaderIndex(); + return typeAndFlags & FRAME_FLAGS_MASK; + } + + public static boolean hasMetadata(ByteBuf byteBuf) { + return (flags(byteBuf) & FLAGS_M) == FLAGS_M; + } + + /** + * faster version of {@link #frameType(ByteBuf)} which does not replace PAYLOAD with synthetic + * type + */ + public static FrameType nativeFrameType(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(Integer.BYTES); + int typeAndFlags = byteBuf.readShort() & 0xFFFF; + FrameType result = FrameType.fromEncodedType(typeAndFlags >> FRAME_TYPE_SHIFT); + byteBuf.resetReaderIndex(); + return result; + } + + public static FrameType frameType(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(Integer.BYTES); + int typeAndFlags = byteBuf.readShort() & 0xFFFF; + + FrameType result = FrameType.fromEncodedType(typeAndFlags >> FRAME_TYPE_SHIFT); + + if (FrameType.PAYLOAD == result) { + final int flags = typeAndFlags & FRAME_FLAGS_MASK; + + boolean complete = FLAGS_C == (flags & FLAGS_C); + boolean next = FLAGS_N == (flags & FLAGS_N); + if (next && complete) { + result = FrameType.NEXT_COMPLETE; + } else if (complete) { + result = FrameType.COMPLETE; + } else if (next) { + result = FrameType.NEXT; + } else { + throw new IllegalArgumentException("Payload must set either or both of NEXT and COMPLETE."); + } + } + + byteBuf.resetReaderIndex(); + + return result; + } + + public static void ensureFrameType(final FrameType frameType, ByteBuf byteBuf) { + if (!disableFrameTypeCheck) { + final FrameType typeInFrame = frameType(byteBuf); + + if (typeInFrame != frameType) { + throw new AssertionError("expected " + frameType + ", but saw " + typeInFrame); + } + } + } + + public static int size() { + return HEADER_SIZE; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java new file mode 100644 index 000000000..f6c19c8ee --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameLengthCodec.java @@ -0,0 +1,54 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +/** + * Some transports like TCP aren't framed, and require a length. This is used by DuplexConnections + * for transports that need to send length + */ +public class FrameLengthCodec { + public static final int FRAME_LENGTH_MASK = 0xFFFFFF; + public static final int FRAME_LENGTH_SIZE = 3; + + private FrameLengthCodec() {} + + private static void encodeLength(final ByteBuf byteBuf, final int length) { + if ((length & ~FRAME_LENGTH_MASK) != 0) { + throw new IllegalArgumentException("Length is larger than 24 bits"); + } + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(length >> 16); + byteBuf.writeByte(length >> 8); + byteBuf.writeByte(length); + } + + private static int decodeLength(final ByteBuf byteBuf) { + int length = (byteBuf.readByte() & 0xFF) << 16; + length |= (byteBuf.readByte() & 0xFF) << 8; + length |= byteBuf.readByte() & 0xFF; + return length; + } + + public static ByteBuf encode(ByteBufAllocator allocator, int length, ByteBuf frame) { + ByteBuf buffer = allocator.buffer(); + encodeLength(buffer, length); + return allocator.compositeBuffer(2).addComponents(true, buffer, frame); + } + + public static int length(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int length = decodeLength(byteBuf); + byteBuf.resetReaderIndex(); + return length; + } + + public static ByteBuf frame(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + byteBuf.skipBytes(3); + ByteBuf slice = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return slice; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameType.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameType.java new file mode 100644 index 000000000..8ac743f87 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameType.java @@ -0,0 +1,315 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import java.util.Arrays; + +/** + * Types of Frame that can be sent. + * + * @see Frame + * Types + */ +public enum FrameType { + + /** Reserved. */ + RESERVED(0x00), + + // CONNECTION + + /** + * Sent by client to initiate protocol processing. + * + * @see Setup + * Frame + */ + SETUP(0x01, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA), + + /** + * Sent by Responder to grant the ability to send requests. + * + * @see Lease + * Frame + */ + LEASE(0x02, Flags.CAN_HAVE_METADATA), + + /** + * Connection keepalive. + * + * @see Keepalive + * Frame + */ + KEEPALIVE(0x03, Flags.CAN_HAVE_DATA), + + // START REQUEST + + /** + * Request single response. + * + * @see Request + * Response Frame + */ + REQUEST_RESPONSE( + 0x04, + Flags.CAN_HAVE_DATA + | Flags.CAN_HAVE_METADATA + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + /** + * A single one-way message. + * + * @see Request + * Fire-and-Forget Frame + */ + REQUEST_FNF( + 0x05, + Flags.CAN_HAVE_DATA + | Flags.CAN_HAVE_METADATA + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + /** + * Request a completable stream. + * + * @see Request + * Stream Frame + */ + REQUEST_STREAM( + 0x06, + Flags.CAN_HAVE_METADATA + | Flags.CAN_HAVE_DATA + | Flags.HAS_INITIAL_REQUEST_N + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + /** + * Request a completable stream in both directions. + * + * @see Request + * Channel Frame + */ + REQUEST_CHANNEL( + 0x07, + Flags.CAN_HAVE_METADATA + | Flags.CAN_HAVE_DATA + | Flags.HAS_INITIAL_REQUEST_N + | Flags.IS_FRAGMENTABLE + | Flags.IS_REQUEST_TYPE), + + // DURING REQUEST + + /** + * Request N more items with Reactive Streams semantics. + * + * @see RequestN + * Frame + */ + REQUEST_N(0x08), + + /** + * Cancel outstanding request. + * + * @see Cancel + * Frame + */ + CANCEL(0x09), + + // RESPONSE + + /** + * Payload on a stream. For example, response to a request, or message on a channel. + * + * @see Payload + * Frame + */ + PAYLOAD(0x0A, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA | Flags.IS_FRAGMENTABLE), + + /** + * Error at connection or application level. + * + * @see Error + * Frame + */ + ERROR(0x0B, Flags.CAN_HAVE_DATA), + + // METADATA + + /** + * Asynchronous Metadata frame. + * + * @see Metadata + * Push Frame + */ + METADATA_PUSH(0x0C, Flags.CAN_HAVE_METADATA), + + // RESUMPTION + + /** + * Replaces SETUP for Resuming Operation (optional). + * + * @see Resume + * Frame + */ + RESUME(0x0D), + + /** + * Sent in response to a RESUME if resuming operation possible (optional). + * + * @see Resume OK + * Frame + */ + RESUME_OK(0x0E), + + // SYNTHETIC PAYLOAD TYPES + + /** A {@link #PAYLOAD} frame with {@code NEXT} flag set. */ + NEXT(0xA0, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA | Flags.IS_FRAGMENTABLE), + + /** A {@link #PAYLOAD} frame with {@code COMPLETE} flag set. */ + COMPLETE(0xB0), + + /** A {@link #PAYLOAD} frame with {@code NEXT} and {@code COMPLETE} flags set. */ + NEXT_COMPLETE(0xC0, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA | Flags.IS_FRAGMENTABLE), + + /** + * Used To Extend more frame types as well as extensions. + * + * @see Extension + * Frame + */ + EXT(0x3F, Flags.CAN_HAVE_DATA | Flags.CAN_HAVE_METADATA); + + /** The size of the encoded frame type */ + static final int ENCODED_SIZE = 6; + + private static final FrameType[] FRAME_TYPES_BY_ENCODED_TYPE; + + static { + FRAME_TYPES_BY_ENCODED_TYPE = new FrameType[getMaximumEncodedType() + 1]; + + for (FrameType frameType : values()) { + FRAME_TYPES_BY_ENCODED_TYPE[frameType.encodedType] = frameType; + } + } + + private final int encodedType; + private final int flags; + + FrameType(int encodedType) { + this(encodedType, Flags.EMPTY); + } + + FrameType(int encodedType, int flags) { + this.encodedType = encodedType; + this.flags = flags; + } + + /** + * Returns the {@code FrameType} that matches the specified {@code encodedType}. + * + * @param encodedType the encoded type + * @return the {@code FrameType} that matches the specified {@code encodedType} + */ + public static FrameType fromEncodedType(int encodedType) { + FrameType frameType = FRAME_TYPES_BY_ENCODED_TYPE[encodedType]; + + if (frameType == null) { + throw new IllegalArgumentException(String.format("Frame type %d is unknown", encodedType)); + } + + return frameType; + } + + private static int getMaximumEncodedType() { + return Arrays.stream(values()).mapToInt(frameType -> frameType.encodedType).max().orElse(0); + } + + /** + * Whether the frame type can have data. + * + * @return whether the frame type can have data + */ + public boolean canHaveData() { + return Flags.CAN_HAVE_DATA == (flags & Flags.CAN_HAVE_DATA); + } + + /** + * Whether the frame type can have metadata + * + * @return whether the frame type can have metadata + */ + public boolean canHaveMetadata() { + return Flags.CAN_HAVE_METADATA == (flags & Flags.CAN_HAVE_METADATA); + } + + /** + * Returns the encoded type. + * + * @return the encoded type + */ + public int getEncodedType() { + return encodedType; + } + + /** + * Whether the frame type starts with an initial {@code requestN}. + * + * @return wether the frame type starts with an initial {@code requestN} + */ + public boolean hasInitialRequestN() { + return Flags.HAS_INITIAL_REQUEST_N == (flags & Flags.HAS_INITIAL_REQUEST_N); + } + + /** + * Whether the frame type is fragmentable. + * + * @return whether the frame type is fragmentable + */ + public boolean isFragmentable() { + return Flags.IS_FRAGMENTABLE == (flags & Flags.IS_FRAGMENTABLE); + } + + /** + * Whether the frame type is a request type. + * + * @return whether the frame type is a request type + */ + public boolean isRequestType() { + return Flags.IS_REQUEST_TYPE == (flags & Flags.IS_REQUEST_TYPE); + } + + private static class Flags { + private static final int EMPTY = 0b00000; + private static final int CAN_HAVE_DATA = 0b10000; + private static final int CAN_HAVE_METADATA = 0b01000; + private static final int IS_FRAGMENTABLE = 0b00100; + private static final int IS_REQUEST_TYPE = 0b00010; + private static final int HAS_INITIAL_REQUEST_N = 0b00001; + + private Flags() {} + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java new file mode 100644 index 000000000..d581731a3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/FrameUtil.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; + +public class FrameUtil { + + private FrameUtil() {} + + public static String toString(ByteBuf frame) { + FrameType frameType = FrameHeaderCodec.frameType(frame); + int streamId = FrameHeaderCodec.streamId(frame); + StringBuilder payload = new StringBuilder(); + + payload + .append("\nFrame => Stream ID: ") + .append(streamId) + .append(" Type: ") + .append(frameType) + .append(" Flags: 0b") + .append(Integer.toBinaryString(FrameHeaderCodec.flags(frame))) + .append(" Length: " + frame.readableBytes()); + + if (frameType.hasInitialRequestN()) { + payload.append(" InitialRequestN: ").append(RequestStreamFrameCodec.initialRequestN(frame)); + } + + if (frameType == FrameType.REQUEST_N) { + payload.append(" RequestN: ").append(RequestNFrameCodec.requestN(frame)); + } + + if (FrameHeaderCodec.hasMetadata(frame)) { + payload.append("\nMetadata:\n"); + + ByteBufUtil.appendPrettyHexDump(payload, getMetadata(frame, frameType)); + } + + payload.append("\nData:\n"); + ByteBufUtil.appendPrettyHexDump(payload, getData(frame, frameType)); + + return payload.toString(); + } + + private static ByteBuf getMetadata(ByteBuf frame, FrameType frameType) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(frame); + if (hasMetadata) { + ByteBuf metadata; + switch (frameType) { + case REQUEST_FNF: + metadata = RequestFireAndForgetFrameCodec.metadata(frame); + break; + case REQUEST_STREAM: + metadata = RequestStreamFrameCodec.metadata(frame); + break; + case REQUEST_RESPONSE: + metadata = RequestResponseFrameCodec.metadata(frame); + break; + case REQUEST_CHANNEL: + metadata = RequestChannelFrameCodec.metadata(frame); + break; + // Payload and synthetic types + case PAYLOAD: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + metadata = PayloadFrameCodec.metadata(frame); + break; + case METADATA_PUSH: + metadata = MetadataPushFrameCodec.metadata(frame); + break; + case SETUP: + metadata = SetupFrameCodec.metadata(frame); + break; + case LEASE: + metadata = LeaseFrameCodec.metadata(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return metadata; + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + private static ByteBuf getData(ByteBuf frame, FrameType frameType) { + ByteBuf data; + switch (frameType) { + case REQUEST_FNF: + data = RequestFireAndForgetFrameCodec.data(frame); + break; + case REQUEST_STREAM: + data = RequestStreamFrameCodec.data(frame); + break; + case REQUEST_RESPONSE: + data = RequestResponseFrameCodec.data(frame); + break; + case REQUEST_CHANNEL: + data = RequestChannelFrameCodec.data(frame); + break; + // Payload, KeepAlive and synthetic types + case PAYLOAD: + case KEEPALIVE: + case NEXT: + case NEXT_COMPLETE: + case COMPLETE: + data = PayloadFrameCodec.data(frame); + break; + case SETUP: + data = SetupFrameCodec.data(frame); + break; + default: + return Unpooled.EMPTY_BUFFER; + } + return data; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java new file mode 100644 index 000000000..56a93d869 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/GenericFrameCodec.java @@ -0,0 +1,159 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +class GenericFrameCodec { + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + final Payload payload) { + return encodeReleasingPayload(allocator, frameType, streamId, complete, next, 0, payload); + } + + static ByteBuf encodeReleasingPayload( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean complete, + boolean next, + int requestN, + final Payload payload) { + + // if refCnt exceptions throws here it is safe to do no-op + boolean hasMetadata = payload.hasMetadata(); + // if refCnt exceptions throws here it is safe to do no-op still + final ByteBuf metadata = hasMetadata ? payload.metadata().retain() : null; + final ByteBuf data; + // retaining data safely. May throw either NPE or RefCntE + try { + data = payload.data().retain(); + } catch (IllegalReferenceCountException | NullPointerException e) { + if (hasMetadata) { + metadata.release(); + } + throw e; + } + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + data.release(); + if (hasMetadata) { + metadata.release(); + } + throw e; + } + + return encode(allocator, frameType, streamId, false, complete, next, requestN, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return encode(allocator, frameType, streamId, fragmentFollows, false, false, 0, metadata, data); + } + + static ByteBuf encode( + final ByteBufAllocator allocator, + final FrameType frameType, + final int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + int requestN, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + if (fragmentFollows) { + flags |= FrameHeaderCodec.FLAGS_F; + } + + if (complete) { + flags |= FrameHeaderCodec.FLAGS_C; + } + + if (next) { + flags |= FrameHeaderCodec.FLAGS_N; + } + + final ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, frameType, flags); + + if (requestN > 0) { + header.writeInt(requestN); + } + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + static ByteBuf data(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + int idx = byteBuf.readerIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.readerIndex(idx); + return data; + } + + @Nullable + static ByteBuf metadata(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static ByteBuf dataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + @Nullable + static ByteBuf metadataWithRequestN(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + static int initialRequestN(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int i = byteBuf.skipBytes(FrameHeaderCodec.size()).readInt(); + byteBuf.resetReaderIndex(); + return i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java new file mode 100644 index 000000000..752d5b3eb --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/KeepAliveFrameCodec.java @@ -0,0 +1,56 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class KeepAliveFrameCodec { + /** + * (R)espond: Set by the sender of the KEEPALIVE, to which the responder MUST reply with a + * KEEPALIVE without the R flag set + */ + public static final int FLAGS_KEEPALIVE_R = 0b00_1000_0000; + + public static final long LAST_POSITION_MASK = 0x8000000000000000L; + + private KeepAliveFrameCodec() {} + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final boolean respond, + final long lastPosition, + final ByteBuf data) { + final int flags = respond ? FLAGS_KEEPALIVE_R : 0; + ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.KEEPALIVE, flags); + + long lp = 0; + if (lastPosition > 0) { + lp |= lastPosition; + } + + header.writeLong(lp); + + return FrameBodyCodec.encode(allocator, header, null, false, data); + } + + public static boolean respondFlag(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + int flags = FrameHeaderCodec.flags(byteBuf); + return (flags & FLAGS_KEEPALIVE_R) == FLAGS_KEEPALIVE_R; + } + + public static long lastPosition(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + byteBuf.markReaderIndex(); + long l = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); + byteBuf.resetReaderIndex(); + return l; + } + + public static ByteBuf data(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.KEEPALIVE, byteBuf); + byteBuf.markReaderIndex(); + ByteBuf slice = byteBuf.skipBytes(FrameHeaderCodec.size() + Long.BYTES).slice(); + byteBuf.resetReaderIndex(); + return slice; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java new file mode 100644 index 000000000..f20c25d3b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/LeaseFrameCodec.java @@ -0,0 +1,83 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import reactor.util.annotation.Nullable; + +public class LeaseFrameCodec { + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final int ttl, + final int numRequests, + @Nullable final ByteBuf metadata) { + + final boolean hasMetadata = metadata != null; + + int flags = 0; + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = + FrameHeaderCodec.encodeStreamZero(allocator, FrameType.LEASE, flags) + .writeInt(ttl) + .writeInt(numRequests); + + final boolean addMetadata; + if (hasMetadata) { + if (metadata.isReadable()) { + addMetadata = true; + } else { + // even though there is nothing to read, we still have to release here since nobody else + // going to do soo + metadata.release(); + addMetadata = false; + } + } else { + // has no metadata means it is null, thus no need to release anything + addMetadata = false; + } + + if (addMetadata) { + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } else { + return header; + } + } + + public static int ttl(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int ttl = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return ttl; + } + + public static int numRequests(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + byteBuf.markReaderIndex(); + // Ttl + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES); + int numRequests = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return numRequests; + } + + @Nullable + public static ByteBuf metadata(final ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.LEASE, byteBuf); + if (FrameHeaderCodec.hasMetadata(byteBuf)) { + byteBuf.markReaderIndex(); + // Ttl + Num of requests + byteBuf.skipBytes(FrameHeaderCodec.size() + Integer.BYTES * 2); + ByteBuf metadata = byteBuf.slice(); + byteBuf.resetReaderIndex(); + return metadata; + } else { + return null; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java new file mode 100644 index 000000000..d8ffe3eef --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/MetadataPushFrameCodec.java @@ -0,0 +1,43 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; + +public class MetadataPushFrameCodec { + + public static ByteBuf encodeReleasingPayload(ByteBufAllocator allocator, Payload payload) { + if (!payload.hasMetadata()) { + throw new IllegalStateException( + "Metadata push requires to have metadata present" + " in the given Payload"); + } + final ByteBuf metadata = payload.metadata().retain(); + // releasing payload safely since it can be already released wheres we have to release retained + // data and metadata as well + try { + payload.release(); + } catch (IllegalReferenceCountException e) { + metadata.release(); + throw e; + } + return encode(allocator, metadata); + } + + public static ByteBuf encode(ByteBufAllocator allocator, ByteBuf metadata) { + ByteBuf header = + FrameHeaderCodec.encodeStreamZero( + allocator, FrameType.METADATA_PUSH, FrameHeaderCodec.FLAGS_M); + return allocator.compositeBuffer(2).addComponents(true, header, metadata); + } + + public static ByteBuf metadata(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int headerSize = FrameHeaderCodec.size(); + int metadataLength = byteBuf.readableBytes() - headerSize; + byteBuf.skipBytes(headerSize); + ByteBuf metadata = byteBuf.readSlice(metadataLength); + byteBuf.resetReaderIndex(); + return metadata; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java new file mode 100644 index 000000000..1ae9c6750 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/PayloadFrameCodec.java @@ -0,0 +1,56 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class PayloadFrameCodec { + + private PayloadFrameCodec() {} + + public static ByteBuf encodeNextReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, false, payload); + } + + public static ByteBuf encodeNextCompleteReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return encodeReleasingPayload(allocator, streamId, true, payload); + } + + static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, boolean complete, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.PAYLOAD, streamId, complete, true, payload); + } + + public static ByteBuf encodeComplete(ByteBufAllocator allocator, int streamId) { + return encode(allocator, streamId, false, true, false, null, null); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + boolean next, + @Nullable ByteBuf metadata, + @Nullable ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.PAYLOAD, streamId, fragmentFollows, complete, next, 0, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java new file mode 100644 index 000000000..60906083d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestChannelFrameCodec.java @@ -0,0 +1,69 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestChannelFrameCodec { + + private RequestChannelFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, + int streamId, + boolean complete, + long initialRequestN, + Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_CHANNEL, streamId, complete, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + boolean complete, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_CHANNEL, + streamId, + fragmentFollows, + complete, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java new file mode 100644 index 000000000..b91199179 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestFireAndForgetFrameCodec.java @@ -0,0 +1,38 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestFireAndForgetFrameCodec { + + private RequestFireAndForgetFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_FNF, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_FNF, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java new file mode 100644 index 000000000..66bdd46f4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestNFrameCodec.java @@ -0,0 +1,30 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class RequestNFrameCodec { + private RequestNFrameCodec() {} + + public static ByteBuf encode( + final ByteBufAllocator allocator, final int streamId, long requestN) { + + if (requestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) requestN; + + ByteBuf header = FrameHeaderCodec.encode(allocator, streamId, FrameType.REQUEST_N, 0); + return header.writeInt(reqN); + } + + public static long requestN(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.REQUEST_N, byteBuf); + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int i = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + return i == Integer.MAX_VALUE ? Long.MAX_VALUE : i; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java new file mode 100644 index 000000000..4a37acfd5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestResponseFrameCodec.java @@ -0,0 +1,37 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestResponseFrameCodec { + + private RequestResponseFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, Payload payload) { + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_RESPONSE, streamId, false, false, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + @Nullable ByteBuf metadata, + ByteBuf data) { + return GenericFrameCodec.encode( + allocator, FrameType.REQUEST_RESPONSE, streamId, fragmentFollows, metadata, data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.data(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadata(byteBuf); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java new file mode 100644 index 000000000..2f5dbf0d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/RequestStreamFrameCodec.java @@ -0,0 +1,64 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import reactor.util.annotation.Nullable; + +public class RequestStreamFrameCodec { + + private RequestStreamFrameCodec() {} + + public static ByteBuf encodeReleasingPayload( + ByteBufAllocator allocator, int streamId, long initialRequestN, Payload payload) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encodeReleasingPayload( + allocator, FrameType.REQUEST_STREAM, streamId, false, false, reqN, payload); + } + + public static ByteBuf encode( + ByteBufAllocator allocator, + int streamId, + boolean fragmentFollows, + long initialRequestN, + @Nullable ByteBuf metadata, + ByteBuf data) { + + if (initialRequestN < 1) { + throw new IllegalArgumentException("request n is less than 1"); + } + + int reqN = initialRequestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) initialRequestN; + + return GenericFrameCodec.encode( + allocator, + FrameType.REQUEST_STREAM, + streamId, + fragmentFollows, + false, + false, + reqN, + metadata, + data); + } + + public static ByteBuf data(ByteBuf byteBuf) { + return GenericFrameCodec.dataWithRequestN(byteBuf); + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + return GenericFrameCodec.metadataWithRequestN(byteBuf); + } + + public static long initialRequestN(ByteBuf byteBuf) { + int requestN = GenericFrameCodec.initialRequestN(byteBuf); + return requestN == Integer.MAX_VALUE ? Long.MAX_VALUE : requestN; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java new file mode 100644 index 000000000..aae89f7ab --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeFrameCodec.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.util.UUID; + +public class ResumeFrameCodec { + static final int CURRENT_VERSION = SetupFrameCodec.CURRENT_VERSION; + + public static ByteBuf encode( + ByteBufAllocator allocator, + ByteBuf token, + long lastReceivedServerPos, + long firstAvailableClientPos) { + + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME, 0); + byteBuf.writeInt(CURRENT_VERSION); + token.markReaderIndex(); + byteBuf.writeShort(token.readableBytes()); + byteBuf.writeBytes(token); + token.resetReaderIndex(); + byteBuf.writeLong(lastReceivedServerPos); + byteBuf.writeLong(firstAvailableClientPos); + + return byteBuf; + } + + public static int version(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + byteBuf.skipBytes(FrameHeaderCodec.size()); + int version = byteBuf.readInt(); + byteBuf.resetReaderIndex(); + + return version; + } + + public static ByteBuf token(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + ByteBuf token = byteBuf.readSlice(tokenLength); + byteBuf.resetReaderIndex(); + + return token; + } + + public static long lastReceivedServerPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + byteBuf.skipBytes(tokenLength); + long lastReceivedServerPos = byteBuf.readLong(); + byteBuf.resetReaderIndex(); + + return lastReceivedServerPos; + } + + public static long firstAvailableClientPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME, byteBuf); + + byteBuf.markReaderIndex(); + // header + version + int tokenPos = FrameHeaderCodec.size() + Integer.BYTES; + byteBuf.skipBytes(tokenPos); + // token + int tokenLength = byteBuf.readShort() & 0xFFFF; + byteBuf.skipBytes(tokenLength); + // last received server position + byteBuf.skipBytes(Long.BYTES); + long firstAvailableClientPos = byteBuf.readLong(); + byteBuf.resetReaderIndex(); + + return firstAvailableClientPos; + } + + public static ByteBuf generateResumeToken() { + UUID uuid = UUID.randomUUID(); + ByteBuf bb = Unpooled.buffer(16); + bb.writeLong(uuid.getMostSignificantBits()); + bb.writeLong(uuid.getLeastSignificantBits()); + return bb; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java new file mode 100644 index 000000000..2b6951e49 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/ResumeOkFrameCodec.java @@ -0,0 +1,22 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +public class ResumeOkFrameCodec { + + public static ByteBuf encode(final ByteBufAllocator allocator, long lastReceivedClientPos) { + ByteBuf byteBuf = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.RESUME_OK, 0); + byteBuf.writeLong(lastReceivedClientPos); + return byteBuf; + } + + public static long lastReceivedClientPos(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.RESUME_OK, byteBuf); + byteBuf.markReaderIndex(); + long lastReceivedClientPosition = byteBuf.skipBytes(FrameHeaderCodec.size()).readLong(); + byteBuf.resetReaderIndex(); + + return lastReceivedClientPosition; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java new file mode 100644 index 000000000..547e2436e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/SetupFrameCodec.java @@ -0,0 +1,226 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import java.nio.charset.StandardCharsets; +import reactor.util.annotation.Nullable; + +public class SetupFrameCodec { + /** + * A flag used to indicate that the client requires connection resumption, if possible (the frame + * contains a Resume Identification Token) + */ + public static final int FLAGS_RESUME_ENABLE = 0b00_1000_0000; + + /** A flag used to indicate that the client will honor LEASE sent by the server */ + public static final int FLAGS_WILL_HONOR_LEASE = 0b00_0100_0000; + + public static final int CURRENT_VERSION = VersionCodec.encode(1, 0); + + private static final int VERSION_FIELD_OFFSET = FrameHeaderCodec.size(); + private static final int KEEPALIVE_INTERVAL_FIELD_OFFSET = VERSION_FIELD_OFFSET + Integer.BYTES; + private static final int KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET = + KEEPALIVE_INTERVAL_FIELD_OFFSET + Integer.BYTES; + private static final int VARIABLE_DATA_OFFSET = + KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET + Integer.BYTES; + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final boolean lease, + final int keepaliveInterval, + final int maxLifetime, + final String metadataMimeType, + final String dataMimeType, + final Payload setupPayload) { + return encode( + allocator, + lease, + keepaliveInterval, + maxLifetime, + Unpooled.EMPTY_BUFFER, + metadataMimeType, + dataMimeType, + setupPayload); + } + + public static ByteBuf encode( + final ByteBufAllocator allocator, + final boolean lease, + final int keepaliveInterval, + final int maxLifetime, + final ByteBuf resumeToken, + final String metadataMimeType, + final String dataMimeType, + final Payload setupPayload) { + + final ByteBuf data = setupPayload.sliceData(); + final boolean hasMetadata = setupPayload.hasMetadata(); + final ByteBuf metadata = hasMetadata ? setupPayload.sliceMetadata() : null; + + int flags = 0; + + if (resumeToken.readableBytes() > 0) { + flags |= FLAGS_RESUME_ENABLE; + } + + if (lease) { + flags |= FLAGS_WILL_HONOR_LEASE; + } + + if (hasMetadata) { + flags |= FrameHeaderCodec.FLAGS_M; + } + + final ByteBuf header = FrameHeaderCodec.encodeStreamZero(allocator, FrameType.SETUP, flags); + + header.writeInt(CURRENT_VERSION).writeInt(keepaliveInterval).writeInt(maxLifetime); + + if ((flags & FLAGS_RESUME_ENABLE) != 0) { + resumeToken.markReaderIndex(); + header.writeShort(resumeToken.readableBytes()).writeBytes(resumeToken); + resumeToken.resetReaderIndex(); + } + + // Write metadata mime-type + int length = ByteBufUtil.utf8Bytes(metadataMimeType); + header.writeByte(length); + ByteBufUtil.writeUtf8(header, metadataMimeType); + + // Write data mime-type + length = ByteBufUtil.utf8Bytes(dataMimeType); + header.writeByte(length); + ByteBufUtil.writeUtf8(header, dataMimeType); + + return FrameBodyCodec.encode(allocator, header, metadata, hasMetadata, data); + } + + public static int version(ByteBuf byteBuf) { + FrameHeaderCodec.ensureFrameType(FrameType.SETUP, byteBuf); + byteBuf.markReaderIndex(); + int version = byteBuf.skipBytes(VERSION_FIELD_OFFSET).readInt(); + byteBuf.resetReaderIndex(); + return version; + } + + public static String humanReadableVersion(ByteBuf byteBuf) { + int encodedVersion = version(byteBuf); + return VersionCodec.major(encodedVersion) + "." + VersionCodec.minor(encodedVersion); + } + + public static boolean isSupportedVersion(ByteBuf byteBuf) { + return CURRENT_VERSION == version(byteBuf); + } + + public static int resumeTokenLength(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int tokenLength = byteBuf.skipBytes(VARIABLE_DATA_OFFSET).readShort() & 0xFFFF; + byteBuf.resetReaderIndex(); + return tokenLength; + } + + public static int keepAliveInterval(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int keepAliveInterval = byteBuf.skipBytes(KEEPALIVE_INTERVAL_FIELD_OFFSET).readInt(); + byteBuf.resetReaderIndex(); + return keepAliveInterval; + } + + public static int keepAliveMaxLifetime(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + int keepAliveMaxLifetime = byteBuf.skipBytes(KEEPALIVE_MAX_LIFETIME_FIELD_OFFSET).readInt(); + byteBuf.resetReaderIndex(); + return keepAliveMaxLifetime; + } + + public static boolean honorLease(ByteBuf byteBuf) { + return (FLAGS_WILL_HONOR_LEASE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_WILL_HONOR_LEASE; + } + + public static boolean resumeEnabled(ByteBuf byteBuf) { + return (FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE; + } + + public static ByteBuf resumeToken(ByteBuf byteBuf) { + if (resumeEnabled(byteBuf)) { + byteBuf.markReaderIndex(); + // header + int resumePos = + FrameHeaderCodec.size() + + + // version + Integer.BYTES + + + // keep-alive interval + Integer.BYTES + + + // keep-alive maxLifeTime + Integer.BYTES; + + int tokenLength = byteBuf.skipBytes(resumePos).readShort() & 0xFFFF; + ByteBuf resumeToken = byteBuf.readSlice(tokenLength); + byteBuf.resetReaderIndex(); + return resumeToken; + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + public static String metadataMimeType(ByteBuf byteBuf) { + int skip = bytesToSkipToMimeType(byteBuf); + byteBuf.markReaderIndex(); + int length = byteBuf.skipBytes(skip).readUnsignedByte(); + String mimeType = byteBuf.slice(byteBuf.readerIndex(), length).toString(StandardCharsets.UTF_8); + byteBuf.resetReaderIndex(); + return mimeType; + } + + public static String dataMimeType(ByteBuf byteBuf) { + int skip = bytesToSkipToMimeType(byteBuf); + byteBuf.markReaderIndex(); + int metadataLength = byteBuf.skipBytes(skip).readByte(); + int dataLength = byteBuf.skipBytes(metadataLength).readByte(); + String mimeType = byteBuf.readSlice(dataLength).toString(StandardCharsets.UTF_8); + byteBuf.resetReaderIndex(); + return mimeType; + } + + @Nullable + public static ByteBuf metadata(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + if (!hasMetadata) { + return null; + } + byteBuf.markReaderIndex(); + skipToPayload(byteBuf); + ByteBuf metadata = FrameBodyCodec.metadataWithoutMarking(byteBuf); + byteBuf.resetReaderIndex(); + return metadata; + } + + public static ByteBuf data(ByteBuf byteBuf) { + boolean hasMetadata = FrameHeaderCodec.hasMetadata(byteBuf); + byteBuf.markReaderIndex(); + skipToPayload(byteBuf); + ByteBuf data = FrameBodyCodec.dataWithoutMarking(byteBuf, hasMetadata); + byteBuf.resetReaderIndex(); + return data; + } + + private static int bytesToSkipToMimeType(ByteBuf byteBuf) { + int bytesToSkip = VARIABLE_DATA_OFFSET; + if ((FLAGS_RESUME_ENABLE & FrameHeaderCodec.flags(byteBuf)) == FLAGS_RESUME_ENABLE) { + bytesToSkip += resumeTokenLength(byteBuf) + Short.BYTES; + } + return bytesToSkip; + } + + private static void skipToPayload(ByteBuf byteBuf) { + int skip = bytesToSkipToMimeType(byteBuf); + byte length = byteBuf.skipBytes(skip).readByte(); + length = byteBuf.skipBytes(length).readByte(); + byteBuf.skipBytes(length); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java new file mode 100644 index 000000000..35e4aa86a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/VersionCodec.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +public class VersionCodec { + + public static int encode(int major, int minor) { + return (major << 16) | (minor & 0xFFFF); + } + + public static int major(int version) { + return version >> 16 & 0xFFFF; + } + + public static int minor(int version) { + return version & 0xFFFF; + } + + public static String toString(int version) { + return major(version) + "." + minor(version); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java new file mode 100644 index 000000000..0d8063e0b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/DefaultPayloadDecoder.java @@ -0,0 +1,69 @@ +package io.rsocket.frame.decoder; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.util.DefaultPayload; +import java.nio.ByteBuffer; + +/** Default Frame decoder that copies the frames contents for easy of use. */ +class DefaultPayloadDecoder implements PayloadDecoder { + + @Override + public Payload apply(ByteBuf byteBuf) { + ByteBuf m; + ByteBuf d; + FrameType type = FrameHeaderCodec.frameType(byteBuf); + switch (type) { + case REQUEST_FNF: + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); + break; + case REQUEST_RESPONSE: + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); + break; + case REQUEST_STREAM: + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); + break; + case REQUEST_CHANNEL: + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); + break; + case NEXT: + case NEXT_COMPLETE: + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); + break; + case METADATA_PUSH: + d = Unpooled.EMPTY_BUFFER; + m = MetadataPushFrameCodec.metadata(byteBuf); + break; + default: + throw new IllegalArgumentException("unsupported frame type: " + type); + } + + ByteBuffer data = ByteBuffer.allocate(d.readableBytes()); + data.put(d.nioBuffer()); + data.flip(); + + if (m != null) { + ByteBuffer metadata = ByteBuffer.allocate(m.readableBytes()); + metadata.put(m.nioBuffer()); + metadata.flip(); + + return DefaultPayload.create(data, metadata); + } + + return DefaultPayload.create(data); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/PayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/PayloadDecoder.java new file mode 100644 index 000000000..197eca9b0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/PayloadDecoder.java @@ -0,0 +1,10 @@ +package io.rsocket.frame.decoder; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import java.util.function.Function; + +public interface PayloadDecoder extends Function { + PayloadDecoder DEFAULT = new DefaultPayloadDecoder(); + PayloadDecoder ZERO_COPY = new ZeroCopyPayloadDecoder(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java new file mode 100644 index 000000000..3a0dc7bb5 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/ZeroCopyPayloadDecoder.java @@ -0,0 +1,58 @@ +package io.rsocket.frame.decoder; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.util.ByteBufPayload; + +/** + * Frame decoder that decodes a frame to a payload without copying. The caller is responsible for + * for releasing the payload to free memory when they no long need it. + */ +public class ZeroCopyPayloadDecoder implements PayloadDecoder { + @Override + public Payload apply(ByteBuf byteBuf) { + ByteBuf m; + ByteBuf d; + FrameType type = FrameHeaderCodec.frameType(byteBuf); + switch (type) { + case REQUEST_FNF: + d = RequestFireAndForgetFrameCodec.data(byteBuf); + m = RequestFireAndForgetFrameCodec.metadata(byteBuf); + break; + case REQUEST_RESPONSE: + d = RequestResponseFrameCodec.data(byteBuf); + m = RequestResponseFrameCodec.metadata(byteBuf); + break; + case REQUEST_STREAM: + d = RequestStreamFrameCodec.data(byteBuf); + m = RequestStreamFrameCodec.metadata(byteBuf); + break; + case REQUEST_CHANNEL: + d = RequestChannelFrameCodec.data(byteBuf); + m = RequestChannelFrameCodec.metadata(byteBuf); + break; + case NEXT: + case NEXT_COMPLETE: + d = PayloadFrameCodec.data(byteBuf); + m = PayloadFrameCodec.metadata(byteBuf); + break; + case METADATA_PUSH: + d = Unpooled.EMPTY_BUFFER; + m = MetadataPushFrameCodec.metadata(byteBuf); + break; + default: + throw new IllegalArgumentException("unsupported frame type: " + type); + } + + return ByteBufPayload.create(d.retain(), m != null ? m.retain() : null); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java new file mode 100644 index 000000000..82e8acaf3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/decoder/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi +package io.rsocket.frame.decoder; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/frame/package-info.java b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java new file mode 100644 index 000000000..69f6d6860 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/frame/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Support for encoding and decoding of RSocket frames to and from {@link io.rsocket.Payload + * Payload}. + */ +@NonNullApi +package io.rsocket.frame; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java new file mode 100644 index 000000000..0296b0a07 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/BaseDuplexConnection.java @@ -0,0 +1,56 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +public abstract class BaseDuplexConnection implements DuplexConnection { + protected final Sinks.Empty onClose = Sinks.empty(); + protected final UnboundedProcessor sender = new UnboundedProcessor(onClose::tryEmitEmpty); + + public BaseDuplexConnection() {} + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + sender.tryEmitPrioritized(frame); + } else { + sender.tryEmitNormal(frame); + } + } + + protected abstract void doOnClose(); + + @Override + public Mono onClose() { + return onClose.asMono(); + } + + @Override + public final void dispose() { + doOnClose(); + } + + @Override + @SuppressWarnings("ConstantConditions") + public final boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/ClientServerInputMultiplexer.java @@ -0,0 +1 @@ + diff --git a/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java new file mode 100644 index 000000000..c96a7aed2 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/UnboundedProcessor.java @@ -0,0 +1,1167 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal; + +import io.netty.buffer.ByteBuf; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Objects; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.stream.Stream; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.Logger; +import reactor.util.annotation.Nullable; +import reactor.util.concurrent.Queues; +import reactor.util.context.Context; + +/** + * A Processor implementation that takes a custom queue and allows only a single subscriber. + * + *

The implementation keeps the order of signals. + */ +public final class UnboundedProcessor extends Flux + implements Scannable, + Disposable, + CoreSubscriber, + Fuseable.QueueSubscription, + Fuseable { + + final Queue queue; + final Queue priorityQueue; + final Runnable onFinalizedHook; + @Nullable final Logger logger; + + boolean cancelled; + boolean done; + Throwable error; + CoreSubscriber actual; + + static final long FLAG_FINALIZED = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_DISPOSED = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_TERMINATED = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_CANCELLED = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_VALUE = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_HAS_REQUEST = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBER_READY = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long FLAG_SUBSCRIBED_ONCE = + 0b0000_0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + static final long MAX_WIP_VALUE = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + volatile long state; + + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "state"); + + volatile int discardGuard; + + static final AtomicIntegerFieldUpdater DISCARD_GUARD = + AtomicIntegerFieldUpdater.newUpdater(UnboundedProcessor.class, "discardGuard"); + + volatile long requested; + + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(UnboundedProcessor.class, "requested"); + + ByteBuf last; + + boolean outputFused; + + public UnboundedProcessor() { + this(() -> {}); + } + + UnboundedProcessor(Logger logger) { + this(() -> {}, logger); + } + + public UnboundedProcessor(Runnable onFinalizedHook) { + this(onFinalizedHook, null); + } + + UnboundedProcessor(Runnable onFinalizedHook, @Nullable Logger logger) { + this.onFinalizedHook = onFinalizedHook; + this.queue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.priorityQueue = new MpscUnboundedArrayQueue<>(Queues.SMALL_BUFFER_SIZE); + this.logger = logger; + } + + @Override + public Stream inners() { + return hasDownstreams() ? Stream.of(Scannable.from(this.actual)) : Stream.empty(); + } + + @Override + public Object scanUnsafe(Attr key) { + if (Attr.ACTUAL == key) return isSubscriberReady(this.state) ? this.actual : null; + if (Attr.BUFFERED == key) return this.queue.size() + this.priorityQueue.size(); + if (Attr.PREFETCH == key) return Integer.MAX_VALUE; + if (Attr.CANCELLED == key) { + final long state = this.state; + return isCancelled(state) || isDisposed(state); + } + + return null; + } + + public boolean tryEmitPrioritized(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + if (!this.priorityQueue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } + + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } + + if (isWorkInProgress(previousState)) { + return true; + } + + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); + } + } + return true; + } + + public boolean tryEmitNormal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + if (!this.queue.offer(t)) { + onError(Operators.onOperatorError(null, Exceptions.failWithOverflow(), t, currentContext())); + release(t); + return false; + } + + final long previousState = markValueAdded(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + return true; + } + + if (isWorkInProgress(previousState)) { + return true; + } + + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_HAS_VALUE) + 1); + } + } + + return true; + } + + public boolean tryEmitFinal(ByteBuf t) { + if (this.done || this.cancelled) { + release(t); + return false; + } + + this.last = t; + this.done = true; + + final long previousState = markValueAddedAndTerminated(this); + if (isFinalized(previousState)) { + this.clearSafely(); + return false; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion + this.actual.onNext(null); + this.actual.onComplete(); + return true; + } + + if (isWorkInProgress(previousState)) { + return true; + } + + drainRegular((previousState | FLAG_TERMINATED | FLAG_HAS_VALUE) + 1); + } + + return true; + } + + @Deprecated + public void onNextPrioritized(ByteBuf t) { + tryEmitPrioritized(t); + } + + @Override + @Deprecated + public void onNext(ByteBuf t) { + tryEmitNormal(t); + } + + @Override + @Deprecated + public void onError(Throwable t) { + if (this.done || this.cancelled) { + Operators.onErrorDropped(t, currentContext()); + return; + } + + this.error = t; + this.done = true; + + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + Operators.onErrorDropped(t, currentContext()); + return; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onError(t); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (!hasValue(previousState)) { + // fast path no-values scenario + this.actual.onError(t); + return; + } + + drainRegular((previousState | FLAG_TERMINATED) + 1); + } + } + + @Override + @Deprecated + public void onComplete() { + if (this.done || this.cancelled) { + return; + } + + this.done = true; + + final long previousState = markTerminatedOrFinalized(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isCancelled(previousState) + || isTerminated(previousState)) { + return; + } + + if (isSubscriberReady(previousState)) { + if (this.outputFused) { + // fast path for fusion scenario + this.actual.onComplete(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (!hasValue(previousState)) { + this.actual.onComplete(); + return; + } + + drainRegular((previousState | FLAG_TERMINATED) + 1); + } + } + + void drainRegular(long expectedState) { + final CoreSubscriber a = this.actual; + final Queue q = this.queue; + final Queue pq = this.priorityQueue; + + for (; ; ) { + + long r = this.requested; + long e = 0L; + + boolean empty = false; + boolean done; + while (r != e) { + // done has to be read before queue.poll to ensure there was no racing: + // Thread1: <#drain>: queue.poll(null) --------------------> this.done(true) + // Thread2: ------------------> <#onNext(V)> --> <#onComplete()> + done = this.done; + + ByteBuf t = pq.poll(); + empty = t == null; + + if (empty) { + t = q.poll(); + empty = t == null; + } + + if (checkTerminated(done, empty, true, a)) { + if (!empty) { + release(t); + } + return; + } + + if (empty) { + break; + } + + a.onNext(t); + + e++; + } + + if (r == e) { + // done has to be read before queue.isEmpty to ensure there was no racing: + // Thread1: <#drain>: queue.isEmpty(true) --------------------> this.done(true) + // Thread2: --------------------> <#onNext(V)> ---> <#onComplete()> + done = this.done; + empty = q.isEmpty() && pq.isEmpty(); + + if (checkTerminated(done, empty, false, a)) { + return; + } + } + + if (e != 0 && r != Long.MAX_VALUE) { + r = REQUESTED.addAndGet(this, -e); + } + + expectedState = markWorkDone(this, expectedState, r > 0, !empty); + if (isCancelled(expectedState)) { + clearAndFinalize(this); + return; + } + + if (isDisposed(expectedState)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); + return; + } + + if (!isWorkInProgress(expectedState)) { + break; + } + } + } + + boolean checkTerminated( + boolean done, boolean empty, boolean hasDemand, CoreSubscriber a) { + final long state = this.state; + if (isCancelled(state)) { + clearAndFinalize(this); + return true; + } + + if (isDisposed(state)) { + clearAndFinalize(this); + a.onError(new CancellationException("Disposed")); + return true; + } + + if (done && empty) { + if (!isTerminated(state)) { + // proactively return if volatile field is not yet set to needed state + return false; + } + final ByteBuf last = this.last; + if (last != null) { + if (!hasDemand) { + return false; + } + this.last = null; + a.onNext(last); + } + clearAndFinalize(this); + Throwable e = this.error; + if (e != null) { + a.onError(e); + } else { + a.onComplete(); + } + return true; + } + + return false; + } + + @Override + public void onSubscribe(Subscription s) { + final long state = this.state; + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { + s.cancel(); + } else { + s.request(Long.MAX_VALUE); + } + } + + @Override + public int getPrefetch() { + return Integer.MAX_VALUE; + } + + @Override + public Context currentContext() { + return isSubscriberReady(this.state) ? this.actual.currentContext() : Context.empty(); + } + + @Override + public void subscribe(CoreSubscriber actual) { + Objects.requireNonNull(actual, "subscribe"); + long previousState = markSubscribedOnce(this); + if (isSubscribedOnce(previousState)) { + Operators.error( + actual, new IllegalStateException("UnboundedProcessor allows only a single Subscriber")); + return; + } + + if (isDisposed(previousState)) { + Operators.error(actual, new CancellationException("Disposed")); + return; + } + + actual.onSubscribe(this); + this.actual = actual; + + previousState = markSubscriberReady(this); + + if (isSubscriberReady(previousState)) { + return; + } + + if (this.outputFused) { + if (isCancelled(previousState)) { + return; + } + + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (hasValue(previousState)) { + actual.onNext(null); + } + + if (isTerminated(previousState)) { + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } + return; + } + + if (isCancelled(previousState)) { + clearAndFinalize(this); + return; + } + + if (isDisposed(previousState)) { + clearAndFinalize(this); + actual.onError(new CancellationException("Disposed")); + return; + } + + if (!hasValue(previousState)) { + if (isTerminated(previousState)) { + clearAndFinalize(this); + final Throwable e = this.error; + if (e != null) { + actual.onError(e); + } else { + actual.onComplete(); + } + } + return; + } + + if (hasRequest(previousState)) { + drainRegular((previousState | FLAG_SUBSCRIBER_READY) + 1); + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (this.outputFused) { + final long state = this.state; + if (isSubscriberReady(state)) { + this.actual.onNext(null); + } + return; + } + + Operators.addCap(REQUESTED, this, n); + + final long previousState = markRequestAdded(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (isSubscriberReady(previousState) && hasValue(previousState)) { + drainRegular((previousState | FLAG_HAS_REQUEST) + 1); + } + } + } + + @Override + public void cancel() { + this.cancelled = true; + + final long previousState = markCancelled(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState) || !this.outputFused) { + clearAndFinalize(this); + } + } + + @Override + @Deprecated + public void dispose() { + this.cancelled = true; + + final long previousState = markDisposed(this); + if (isWorkInProgress(previousState) + || isFinalized(previousState) + || isCancelled(previousState) + || isDisposed(previousState)) { + return; + } + + if (!isSubscribedOnce(previousState)) { + clearAndFinalize(this); + return; + } + + if (!isSubscriberReady(previousState)) { + return; + } + + if (!this.outputFused) { + clearAndFinalize(this); + this.actual.onError(new CancellationException("Disposed")); + return; + } + + if (!isTerminated(previousState)) { + this.actual.onError(new CancellationException("Disposed")); + } + } + + @Override + @Nullable + public ByteBuf poll() { + ByteBuf t = this.priorityQueue.poll(); + if (t != null) { + return t; + } + + t = this.queue.poll(); + if (t != null) { + return t; + } + + t = this.last; + if (t != null) { + this.last = null; + return t; + } + + return null; + } + + @Override + public int size() { + return this.priorityQueue.size() + this.queue.size(); + } + + @Override + public boolean isEmpty() { + return this.priorityQueue.isEmpty() && this.queue.isEmpty(); + } + + /** + * Clears all elements from queues and set state to terminate. This method MUST be called only by + * the downstream subscriber which has enabled {@link Fuseable#ASYNC} fusion with the given {@link + * UnboundedProcessor} and is and indicator that the downstream is done with draining, it has + * observed any terminal signal (ON_COMPLETE or ON_ERROR or CANCEL) and will never be interacting + * with SingleConsumer queue anymore. + */ + @Override + public void clear() { + clearAndFinalize(this); + } + + void clearSafely() { + if (DISCARD_GUARD.getAndIncrement(this) != 0) { + return; + } + + int missed = 1; + for (; ; ) { + clearUnsafely(); + + missed = DISCARD_GUARD.addAndGet(this, -missed); + if (missed == 0) { + break; + } + } + } + + void clearUnsafely() { + final Queue queue = this.queue; + final Queue priorityQueue = this.priorityQueue; + + final ByteBuf last = this.last; + + if (last != null) { + release(last); + } + + ByteBuf byteBuf; + while ((byteBuf = queue.poll()) != null) { + release(byteBuf); + } + + while ((byteBuf = priorityQueue.poll()) != null) { + release(byteBuf); + } + } + + @Override + public int requestFusion(int requestedMode) { + if ((requestedMode & Fuseable.ASYNC) != 0) { + this.outputFused = true; + return Fuseable.ASYNC; + } + return Fuseable.NONE; + } + + @Override + public boolean isDisposed() { + return isFinalized(this.state); + } + + boolean hasDownstreams() { + final long state = this.state; + return !isTerminated(state) && isSubscriberReady(state); + } + + static void release(ByteBuf byteBuf) { + if (byteBuf.refCnt() > 0) { + try { + byteBuf.release(); + } catch (Throwable ex) { + // no ops + } + } + } + + /** + * Sets {@link #FLAG_SUBSCRIBED_ONCE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return {@code true} if {@link #FLAG_SUBSCRIBED_ONCE} was successfully set + */ + static long markSubscribedOnce(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isSubscribedOnce(state)) { + return state; + } + + final long nextState = state | FLAG_SUBSCRIBED_ONCE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mso", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_SUBSCRIBER_READY} flag if flags {@link #FLAG_FINALIZED}, {@link + * #FLAG_CANCELLED} or {@link #FLAG_DISPOSED} are unset + * + * @return previous state + */ + static long markSubscriberReady(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) + || isCancelled(state) + || isDisposed(state) + || isSubscriberReady(state)) { + return state; + } + + long nextState = state; + if (!instance.outputFused) { + if ((!hasValue(state) && isTerminated(state)) || (hasRequest(state) && hasValue(state))) { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_SUBSCRIBER_READY; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " msr", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_REQUEST} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markRequestAdded(UnboundedProcessor instance) { + for (; ; ) { + long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state) || (isSubscriberReady(state) && hasValue(state))) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_REQUEST; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mra", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAdded(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state)) { + if (instance.outputFused) { + // fast path for fusion scenario + return state; + } + + if (hasRequest(state)) { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_HAS_VALUE; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mva", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_HAS_VALUE} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) if {@link #FLAG_HAS_REQUEST} is set + * + * @return previous state + */ + static long markValueAddedAndTerminated(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + nextState = addWork(state); + } + + nextState = nextState | FLAG_HAS_VALUE | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, "mva&t", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_TERMINATED} flag if it was not set before and if flags {@link + * #FLAG_FINALIZED}, {@link #FLAG_CANCELLED}, {@link #FLAG_DISPOSED} are unset. Also, this method + * increments number of work in progress (WIP) + * + * @return previous state + */ + static long markTerminatedOrFinalized(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isTerminated(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + long nextState = state; + if (isWorkInProgress(state)) { + nextState = addWork(state); + } else if (isSubscriberReady(state) && !instance.outputFused) { + if (!hasValue(state)) { + // fast path for no values and no work in progress + nextState = FLAG_FINALIZED; + } else { + nextState = addWork(state); + } + } + + nextState = nextState | FLAG_TERMINATED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mt|f", state, nextState); + if (isFinalized(nextState)) { + instance.onFinalizedHook.run(); + } + return state; + } + } + } + + /** + * Sets {@link #FLAG_CANCELLED} flag if it was not set before and if flag {@link #FLAG_FINALIZED} + * is unset. Also, this method increments number of work in progress (WIP) + * + * @return previous state + */ + static long markCancelled(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_CANCELLED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mc", state, nextState); + return state; + } + } + } + + /** + * Sets {@link #FLAG_DISPOSED} flag if it was not set before and if flags {@link #FLAG_FINALIZED}, + * {@link #FLAG_CANCELLED} are unset. Also, this method increments number of work in progress + * (WIP) + * + * @return previous state + */ + static long markDisposed(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = addWork(state) | FLAG_DISPOSED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " md", state, nextState); + return state; + } + } + } + + static long addWork(long state) { + return (state & MAX_WIP_VALUE) == MAX_WIP_VALUE ? state : state + 1; + } + + /** + * Decrements the amount of work in progress by the given amount on the given state. Fails if flag + * is {@link #FLAG_FINALIZED} is set or if fusion disabled and flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set. + * + *

Note, if fusion is enabled, the decrement should work if flags {@link #FLAG_CANCELLED} or + * {@link #FLAG_DISPOSED} are set, since, while the operator was not terminate by the downstream, + * we still have to propagate notifications that new elements are enqueued + * + * @return state after changing WIP or current state if update failed + */ + static long markWorkDone( + UnboundedProcessor instance, long expectedState, boolean hasRequest, boolean hasValue) { + for (; ; ) { + final long state = instance.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state) || isCancelled(state) || isDisposed(state)) { + return state; + } + + final long nextState = + (state - (expectedState & MAX_WIP_VALUE)) + ^ (hasRequest ? 0 : FLAG_HAS_REQUEST) + ^ (hasValue ? 0 : FLAG_HAS_VALUE); + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " mwd", state, nextState); + return nextState; + } + } + } + + /** + * Set flag {@link #FLAG_FINALIZED} and {@link #release(ByteBuf)} all the elements from {@link + * #queue} and {@link #priorityQueue}. + * + *

This method may be called concurrently only if the given {@link UnboundedProcessor} has no + * output fusion ({@link #outputFused} {@code == true}). Otherwise this method MUST only by the + * downstream calling method {@link #clear()} + */ + static void clearAndFinalize(UnboundedProcessor instance) { + for (; ; ) { + final long state = instance.state; + + if (isFinalized(state)) { + instance.clearSafely(); + return; + } + + if (!isSubscriberReady(state) || !instance.outputFused) { + instance.clearSafely(); + } else { + instance.clearUnsafely(); + } + + long nextState = (state & ~MAX_WIP_VALUE & ~FLAG_HAS_VALUE) | FLAG_FINALIZED; + if (STATE.compareAndSet(instance, state, nextState)) { + log(instance, " c&f", state, nextState); + instance.onFinalizedHook.run(); + break; + } + } + } + + static boolean hasValue(long state) { + return (state & FLAG_HAS_VALUE) == FLAG_HAS_VALUE; + } + + static boolean hasRequest(long state) { + return (state & FLAG_HAS_REQUEST) == FLAG_HAS_REQUEST; + } + + static boolean isCancelled(long state) { + return (state & FLAG_CANCELLED) == FLAG_CANCELLED; + } + + static boolean isDisposed(long state) { + return (state & FLAG_DISPOSED) == FLAG_DISPOSED; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WIP_VALUE) != 0; + } + + static boolean isTerminated(long state) { + return (state & FLAG_TERMINATED) == FLAG_TERMINATED; + } + + static boolean isFinalized(long state) { + return (state & FLAG_FINALIZED) == FLAG_FINALIZED; + } + + static boolean isSubscriberReady(long state) { + return (state & FLAG_SUBSCRIBER_READY) == FLAG_SUBSCRIBER_READY; + } + + static boolean isSubscribedOnce(long state) { + return (state & FLAG_SUBSCRIBED_ONCE) == FLAG_SUBSCRIBED_ONCE; + } + + static void log( + UnboundedProcessor instance, String action, long initialState, long committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + long initialState, + long committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 64), + formatState(committedState, 64))); + } + } + + static void log( + UnboundedProcessor instance, String action, int initialState, int committedState) { + log(instance, action, initialState, committedState, false); + } + + static void log( + UnboundedProcessor instance, + String action, + int initialState, + int committedState, + boolean logStackTrace) { + Logger logger = instance.logger; + if (logger == null || !logger.isTraceEnabled()) { + return; + } + + if (logStackTrace) { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32)), + new RuntimeException()); + } else { + logger.trace( + String.format( + "[%s][%s][%s][%s-%s]", + instance, + action, + Thread.currentThread().getId(), + formatState(initialState, 32), + formatState(committedState, 32))); + } + } + + static String formatState(long state, int size) { + final String defaultFormat = Long.toBinaryString(state); + final StringBuilder formatted = new StringBuilder(); + final int toPrepend = size - defaultFormat.length(); + for (int i = 0; i < size; i++) { + if (i != 0 && i % 4 == 0) { + formatted.append("_"); + } + if (i < toPrepend) { + formatted.append("0"); + } else { + formatted.append(defaultFormat.charAt(i - toPrepend)); + } + } + + formatted.insert(0, "0b"); + return formatted.toString(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java new file mode 100644 index 000000000..a99ef8a49 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseLinkedQueue.java @@ -0,0 +1,302 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; + +import java.util.AbstractQueue; +import java.util.Iterator; + +abstract class BaseLinkedQueuePad0 extends AbstractQueue implements MessagePassingQueue { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + // * drop 8b as object header acts as padding and is >= 8b * +} + +// $gen:ordered-fields +abstract class BaseLinkedQueueProducerNodeRef extends BaseLinkedQueuePad0 { + static final long P_NODE_OFFSET = + fieldOffset(BaseLinkedQueueProducerNodeRef.class, "producerNode"); + + private volatile LinkedQueueNode producerNode; + + final void spProducerNode(LinkedQueueNode newValue) { + UNSAFE.putObject(this, P_NODE_OFFSET, newValue); + } + + final void soProducerNode(LinkedQueueNode newValue) { + UNSAFE.putOrderedObject(this, P_NODE_OFFSET, newValue); + } + + final LinkedQueueNode lvProducerNode() { + return producerNode; + } + + final boolean casProducerNode(LinkedQueueNode expect, LinkedQueueNode newValue) { + return UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, expect, newValue); + } + + final LinkedQueueNode lpProducerNode() { + return producerNode; + } +} + +abstract class BaseLinkedQueuePad1 extends BaseLinkedQueueProducerNodeRef { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseLinkedQueueConsumerNodeRef extends BaseLinkedQueuePad1 { + private static final long C_NODE_OFFSET = + fieldOffset(BaseLinkedQueueConsumerNodeRef.class, "consumerNode"); + + private LinkedQueueNode consumerNode; + + final void spConsumerNode(LinkedQueueNode newValue) { + consumerNode = newValue; + } + + @SuppressWarnings("unchecked") + final LinkedQueueNode lvConsumerNode() { + return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, C_NODE_OFFSET); + } + + final LinkedQueueNode lpConsumerNode() { + return consumerNode; + } +} + +abstract class BaseLinkedQueuePad2 extends BaseLinkedQueueConsumerNodeRef { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +/** + * A base data structure for concurrent linked queues. For convenience also pulled in common single + * consumer methods since at this time there's no plan to implement MC. + */ +abstract class BaseLinkedQueue extends BaseLinkedQueuePad2 { + + @Override + public final Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + protected final LinkedQueueNode newNode() { + return new LinkedQueueNode(); + } + + protected final LinkedQueueNode newNode(E e) { + return new LinkedQueueNode(e); + } + + /** + * {@inheritDoc}
+ * + *

IMPLEMENTATION NOTES:
+ * This is an O(n) operation as we run through all the nodes and count them.
+ * The accuracy of the value returned by this method is subject to races with producer/consumer + * threads. In particular when racing with the consumer thread this method may under estimate the + * size.
+ * + * @see java.util.Queue#size() + */ + @Override + public final int size() { + // Read consumer first, this is important because if the producer is node is 'older' than the + // consumer + // the consumer may overtake it (consume past it) invalidating the 'snapshot' notion of size. + LinkedQueueNode chaserNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + int size = 0; + // must chase the nodes all the way to the producer node, but there's no need to count beyond + // expected head. + while (chaserNode != producerNode + && // don't go passed producer node + chaserNode != null + && // stop at last node + size < Integer.MAX_VALUE) // stop at max int + { + LinkedQueueNode next; + next = chaserNode.lvNext(); + // check if this node has been consumed, if so return what we have + if (next == chaserNode) { + return size; + } + chaserNode = next; + size++; + } + return size; + } + + /** + * {@inheritDoc}
+ * + *

IMPLEMENTATION NOTES:
+ * Queue is empty when producerNode is the same as consumerNode. An alternative implementation + * would be to observe the producerNode.value is null, which also means an empty queue because + * only the consumerNode.value is allowed to be null. + * + * @see MessagePassingQueue#isEmpty() + */ + @Override + public boolean isEmpty() { + LinkedQueueNode consumerNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + return consumerNode == producerNode; + } + + protected E getSingleConsumerNodeValue( + LinkedQueueNode currConsumerNode, LinkedQueueNode nextNode) { + // we have to null out the value because we are going to hang on to the node + final E nextValue = nextNode.getAndNullValue(); + + // Fix up the next ref of currConsumerNode to prevent promoted nodes from keeping new ones + // alive. + // We use a reference to self instead of null because null is already a meaningful value (the + // next of + // producer node is null). + currConsumerNode.soNext(currConsumerNode); + spConsumerNode(nextNode); + // currConsumerNode is now no longer referenced and can be collected + return nextValue; + } + + @Override + public E relaxedPoll() { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + final LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + @Override + public E relaxedPeek() { + final LinkedQueueNode nextNode = lpConsumerNode().lvNext(); + if (nextNode != null) { + return nextNode.lpValue(); + } + return null; + } + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @Override + public int drain(Consumer c) { + long result = 0; // use long to force safepoint into loop below + int drained; + do { + drained = drain(c, 4096); + result += drained; + } while (drained == 4096 && result <= Integer.MAX_VALUE - 4096); + return (int) result; + } + + @Override + public int drain(Consumer c, int limit) { + LinkedQueueNode chaserNode = this.lpConsumerNode(); + for (int i = 0; i < limit; i++) { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + + if (nextNode == null) { + return i; + } + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + return limit; + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + LinkedQueueNode chaserNode = this.lpConsumerNode(); + int idleCounter = 0; + while (exit.keepRunning()) { + for (int i = 0; i < 4096; i++) { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + if (nextNode == null) { + idleCounter = wait.idle(idleCounter); + continue; + } + + idleCounter = 0; + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + } + } + + @Override + public int capacity() { + return UNBOUNDED_CAPACITY; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java new file mode 100644 index 000000000..cfad5ef71 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/BaseMpscLinkedArrayQueue.java @@ -0,0 +1,705 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.modifiedCalcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.allocateRefArray; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcCircularRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.calcRefElementOffset; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.lvRefElement; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.soRefElement; + +import io.rsocket.internal.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; + +abstract class BaseMpscLinkedArrayQueuePad1 extends AbstractQueue implements IndexedQueue { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueProducerFields extends BaseMpscLinkedArrayQueuePad1 { + private static final long P_INDEX_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final void soProducerIndex(long newValue) { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final boolean casProducerIndex(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad2 extends BaseMpscLinkedArrayQueueProducerFields { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueConsumerFields extends BaseMpscLinkedArrayQueuePad2 { + private static final long C_INDEX_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueConsumerFields.class, "consumerIndex"); + + private volatile long consumerIndex; + protected long consumerMask; + protected E[] consumerBuffer; + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final long lpConsumerIndex() { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad3 extends BaseMpscLinkedArrayQueueConsumerFields { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueColdProducerFields + extends BaseMpscLinkedArrayQueuePad3 { + private static final long P_LIMIT_OFFSET = + fieldOffset(BaseMpscLinkedArrayQueueColdProducerFields.class, "producerLimit"); + + private volatile long producerLimit; + protected long producerMask; + protected E[] producerBuffer; + + final long lvProducerLimit() { + return producerLimit; + } + + final boolean casProducerLimit(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_LIMIT_OFFSET, expect, newValue); + } + + final void soProducerLimit(long newValue) { + UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue); + } +} + +/** + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in + * linked chunks of the initial size. The queue grows only when the current buffer is full and + * elements are not copied on resize, instead a link to the new buffer is stored in the old buffer + * for the consumer to follow. + */ +abstract class BaseMpscLinkedArrayQueue extends BaseMpscLinkedArrayQueueColdProducerFields + implements MessagePassingQueue, QueueProgressIndicators { + // No post padding here, subclasses must add + private static final Object JUMP = new Object(); + private static final Object BUFFER_CONSUMED = new Object(); + private static final int CONTINUE_TO_P_INDEX_CAS = 0; + private static final int RETRY = 1; + private static final int QUEUE_FULL = 2; + private static final int QUEUE_RESIZE = 3; + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the + * chunk size. Must be 2 or more. + */ + public BaseMpscLinkedArrayQueue(final int initialCapacity) { + RangeUtil.checkGreaterThanOrEqual(initialCapacity, 2, "initialCapacity"); + + int p2capacity = Pow2.roundToPowerOfTwo(initialCapacity); + // leave lower bit of mask clear + long mask = (p2capacity - 1) << 1; + // need extra element to point at next array + E[] buffer = allocateRefArray(p2capacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + soProducerLimit(mask); // we know it's all empty to start with + } + + @Override + public int size() { + // NOTE: because indices are on even numbers we cannot use the size util. + + /* + * It is possible for a thread to be interrupted or reschedule between the read of the producer and + * consumer indices, therefore protection is required to ensure size is within valid range. In the + * event of concurrent polls/offers to this method the size is OVER estimated as we read consumer + * index BEFORE the producer index. + */ + long after = lvConsumerIndex(); + long size; + while (true) { + final long before = after; + final long currentProducerIndex = lvProducerIndex(); + after = lvConsumerIndex(); + if (before == after) { + size = ((currentProducerIndex - after) >> 1); + break; + } + } + // Long overflow is impossible, so size is always positive. Integer overflow is possible for the + // unbounded + // indexed queues. + if (size > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) size; + } + } + + @Override + public boolean isEmpty() { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there + // is + // nothing we can do to make this an exact method. + return (this.lvConsumerIndex() == this.lvProducerIndex()); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public boolean offer(final E e) { + if (null == e) { + throw new NullPointerException(); + } + + long mask; + E[] buffer; + long pIndex; + + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // mask/buffer may get changed by resizing -> only use for array access after successful CAS. + mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) - [mask/buffer] -> cas(pIndex) + + // assumption behind this optimization is that queue is almost always empty or near empty + if (producerLimit <= pIndex) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) { + case CONTINUE_TO_P_INDEX_CAS: + break; + case RETRY: + continue; + case QUEUE_FULL: + return false; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, e, null); + return true; + } + } + + if (casProducerIndex(pIndex, pIndex + 2)) { + break; + } + } + // INDEX visible before ELEMENT + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); // release element e + return true; + } + + /** + * {@inheritDoc} + * + *

This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + if (index != lvProducerIndex()) { + // poll() == null iff queue is empty, null element is not strong enough indicator, so we + // must + // check the producer index. If the queue is indeed not empty we spin until element is + // visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } else { + return null; + } + } + + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, index); + } + + soRefElement(buffer, offset, null); // release element null + soConsumerIndex(index + 2); // release cIndex + return (E) e; + } + + /** + * {@inheritDoc} + * + *

This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == null && index != lvProducerIndex()) { + // peek() == null iff queue is empty, null element is not strong enough indicator, so we must + // check the producer index. If the queue is indeed not empty we spin until element is + // visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), index); + } + return (E) e; + } + + /** We do not inline resize into this method because we do not resize on fill. */ + private int offerSlowPath(long mask, long pIndex, long producerLimit) { + final long cIndex = lvConsumerIndex(); + long bufferCapacity = getCurrentBufferCapacity(mask); + + if (cIndex + bufferCapacity > pIndex) { + if (!casProducerLimit(producerLimit, cIndex + bufferCapacity)) { + // retry from top + return RETRY; + } else { + // continue to pIndex CAS + return CONTINUE_TO_P_INDEX_CAS; + } + } + // full and cannot grow + else if (availableInQueue(pIndex, cIndex) <= 0) { + // offer should return false; + return QUEUE_FULL; + } + // grab index for resize -> set lower bit + else if (casProducerIndex(pIndex, pIndex + 1)) { + // trigger a resize + return QUEUE_RESIZE; + } else { + // failed resize attempt, retry from top + return RETRY; + } + } + + /** @return available elements in queue * 2 */ + protected abstract long availableInQueue(long pIndex, long cIndex); + + @SuppressWarnings("unchecked") + private E[] nextBuffer(final E[] buffer, final long mask) { + final long offset = nextArrayOffset(mask); + final E[] nextBuffer = (E[]) lvRefElement(buffer, offset); + consumerBuffer = nextBuffer; + consumerMask = (length(nextBuffer) - 2) << 1; + soRefElement(buffer, offset, BUFFER_CONSUMED); + return nextBuffer; + } + + private static long nextArrayOffset(long mask) { + return modifiedCalcCircularRefElementOffset(mask + 2, Long.MAX_VALUE); + } + + private E newBufferPoll(E[] nextBuffer, long index) { + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (n == null) { + throw new IllegalStateException("new buffer must have at least one element"); + } + soRefElement(nextBuffer, offset, null); + soConsumerIndex(index + 2); + return n; + } + + private E newBufferPeek(E[] nextBuffer, long index) { + final long offset = modifiedCalcCircularRefElementOffset(index, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) { + throw new IllegalStateException("new buffer must have at least one element"); + } + return n; + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex() / 2; + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex() / 2; + } + + @Override + public abstract int capacity(); + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPoll() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + return null; + } + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, index); + } + soRefElement(buffer, offset, null); + soConsumerIndex(index + 2); + return (E) e; + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPeek() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(index, mask); + Object e = lvRefElement(buffer, offset); + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), index); + } + return (E) e; + } + + @Override + public int fill(Supplier s) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + final int capacity = capacity(); + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + @Override + public int fill(Supplier s, int limit) { + if (null == s) throw new IllegalArgumentException("supplier is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) return 0; + + long mask; + E[] buffer; + long pIndex; + int claimedSlots; + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // NOTE: mask/buffer may get changed by resizing -> only use for array access after successful + // CAS. + // Only by virtue offloading them between the lvProducerIndex and a successful + // casProducerIndex are they + // safe to use. + mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) + + // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' + long batchIndex = + Math.min(producerLimit, pIndex + 2l * limit); // -> producerLimit >= batchIndex + + if (pIndex >= producerLimit) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) { + case CONTINUE_TO_P_INDEX_CAS: + // offer slow path verifies only one slot ahead, we cannot rely on indication here + case RETRY: + continue; + case QUEUE_FULL: + return 0; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, null, s); + return 1; + } + } + + // claim limit slots at once + if (casProducerIndex(pIndex, batchIndex)) { + claimedSlots = (int) ((batchIndex - pIndex) / 2); + break; + } + } + + for (int i = 0; i < claimedSlots; i++) { + final long offset = modifiedCalcCircularRefElementOffset(pIndex + 2l * i, mask); + soRefElement(buffer, offset, s.get()); + } + return claimedSlots; + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + + @Override + public int drain(Consumer c) { + return drain(c, capacity()); + } + + @Override + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + /** + * Get an iterator for this queue. This method is thread safe. + * + *

The iterator provides a best-effort snapshot of the elements in the queue. The returned + * iterator is not guaranteed to return elements in queue order, and races with the consumer + * thread may cause gaps in the sequence of returned elements. Like {link #relaxedPoll}, the + * iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + return new WeakIterator(consumerBuffer, lvConsumerIndex(), lvProducerIndex()); + } + + private static class WeakIterator implements Iterator { + private final long pIndex; + private long nextIndex; + private E nextElement; + private E[] currentBuffer; + private int mask; + + WeakIterator(E[] currentBuffer, long cIndex, long pIndex) { + this.pIndex = pIndex >> 1; + this.nextIndex = cIndex >> 1; + setBuffer(currentBuffer); + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + final E e = nextElement; + if (e == null) { + throw new NoSuchElementException(); + } + nextElement = getNext(); + return e; + } + + private void setBuffer(E[] buffer) { + this.currentBuffer = buffer; + this.mask = length(buffer) - 2; + } + + private E getNext() { + while (nextIndex < pIndex) { + long index = nextIndex++; + E e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } + + // not null && not JUMP -> found next element + if (e != JUMP) { + return e; + } + + // need to jump to the next buffer + int nextBufferIndex = mask + 1; + Object nextBuffer = lvRefElement(currentBuffer, calcRefElementOffset(nextBufferIndex)); + + if (nextBuffer == BUFFER_CONSUMED || nextBuffer == null) { + // Consumer may have passed us, or the next buffer is not visible yet: drop out early + return null; + } + + setBuffer((E[]) nextBuffer); + // now with the new array retry the load, it can't be a JUMP, but we need to repeat same + // index + e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } else { + return e; + } + } + return null; + } + } + + private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s) { + assert (e != null && s == null) || (e == null || s != null); + int newBufferLength = getNextBufferSize(oldBuffer); + final E[] newBuffer; + try { + newBuffer = allocateRefArray(newBufferLength); + } catch (OutOfMemoryError oom) { + assert lvProducerIndex() == pIndex + 1; + soProducerIndex(pIndex); + throw oom; + } + + producerBuffer = newBuffer; + final int newMask = (newBufferLength - 2) << 1; + producerMask = newMask; + + final long offsetInOld = modifiedCalcCircularRefElementOffset(pIndex, oldMask); + final long offsetInNew = modifiedCalcCircularRefElementOffset(pIndex, newMask); + + soRefElement(newBuffer, offsetInNew, e == null ? s.get() : e); // element in new array + soRefElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); // buffer linked + + // ASSERT code + final long cIndex = lvConsumerIndex(); + final long availableInQueue = availableInQueue(pIndex, cIndex); + RangeUtil.checkPositive(availableInQueue, "availableInQueue"); + + // Invalidate racing CASs + // We never set the limit beyond the bounds of a buffer + soProducerLimit(pIndex + Math.min(newMask, availableInQueue)); + + // make resize visible to the other producers + soProducerIndex(pIndex + 2); + + // INDEX visible before ELEMENT, consistent with consumer expectation + + // make resize visible to consumer + soRefElement(oldBuffer, offsetInOld, JUMP); + } + + /** @return next buffer size(inclusive of next array pointer) */ + protected abstract int getNextBufferSize(E[] buffer); + + /** @return current buffer capacity for elements (excluding next pointer and jump entry) * 2 */ + protected abstract long getCurrentBufferCapacity(long mask); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java new file mode 100644 index 000000000..40116bbe1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/IndexedQueueSizeUtil.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +final class IndexedQueueSizeUtil { + public static int size(IndexedQueue iq) { + /* + * It is possible for a thread to be interrupted or reschedule between the read of the producer and + * consumer indices, therefore protection is required to ensure size is within valid range. In the + * event of concurrent polls/offers to this method the size is OVER estimated as we read consumer + * index BEFORE the producer index. + */ + long after = iq.lvConsumerIndex(); + long size; + while (true) { + final long before = after; + final long currentProducerIndex = iq.lvProducerIndex(); + after = iq.lvConsumerIndex(); + if (before == after) { + size = (currentProducerIndex - after); + break; + } + } + // Long overflow is impossible (), so size is always positive. Integer overflow is possible for + // the unbounded + // indexed queues. + if (size > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) size; + } + } + + public static boolean isEmpty(IndexedQueue iq) { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there + // is + // nothing we can do to make this an exact method. + return (iq.lvConsumerIndex() == iq.lvProducerIndex()); + } + + public interface IndexedQueue { + long lvConsumerIndex(); + + long lvProducerIndex(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java new file mode 100644 index 000000000..37651f351 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedArrayQueueUtil.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ARRAY_BASE; +import static io.rsocket.internal.jctools.queues.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; + +/** This is used for method substitution in the LinkedArray classes code generation. */ +final class LinkedArrayQueueUtil { + static int length(Object[] buf) { + return buf.length; + } + + /** + * This method assumes index is actually (index << 1) because lower bit is used for resize. This + * is compensated for by reducing the element shift. The computation is constant folded, so + * there's no cost. + */ + static long modifiedCalcCircularRefElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << (REF_ELEMENT_SHIFT - 1)); + } + + static long nextArrayOffset(Object[] curr) { + return REF_ARRAY_BASE + ((long) (length(curr) - 1) << REF_ELEMENT_SHIFT); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java new file mode 100644 index 000000000..72e78bb92 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/LinkedQueueNode.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; +import static io.rsocket.internal.jctools.queues.UnsafeAccess.fieldOffset; + +final class LinkedQueueNode { + private static final long NEXT_OFFSET = fieldOffset(LinkedQueueNode.class, "next"); + + private E value; + private volatile LinkedQueueNode next; + + LinkedQueueNode() { + this(null); + } + + LinkedQueueNode(E val) { + spValue(val); + } + + /** + * Gets the current value and nulls out the reference to it from this node. + * + * @return value + */ + public E getAndNullValue() { + E temp = lpValue(); + spValue(null); + return temp; + } + + public E lpValue() { + return value; + } + + public void spValue(E newValue) { + value = newValue; + } + + public void soNext(LinkedQueueNode n) { + UNSAFE.putOrderedObject(this, NEXT_OFFSET, n); + } + + public void spNext(LinkedQueueNode n) { + UNSAFE.putObject(this, NEXT_OFFSET, n); + } + + public LinkedQueueNode lvNext() { + return next; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java new file mode 100644 index 000000000..7a0fa901f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueue.java @@ -0,0 +1,339 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import java.util.Queue; + +/** + * Message passing queues are intended for concurrent method passing. A subset of {@link Queue} + * methods are provided with the same semantics, while further functionality which accomodates the + * concurrent usecase is also on offer. + * + *

Message passing queues provide happens before semantics to messages passed through, namely + * that writes made by the producer before offering the message are visible to the consuming thread + * after the message has been polled out of the queue. + * + * @param the event/message type + */ +public interface MessagePassingQueue { + int UNBOUNDED_CAPACITY = -1; + + interface Supplier { + /** + * This method will return the next value to be written to the queue. As such the queue + * implementations are commited to insert the value once the call is made. + * + *

Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the supplier + * method. In particular size and any offer methods may take the view that the full batch has + * already happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + *

WARNING: this method is assumed to never return {@code null}. Breaking this + * assumption can lead to a broken queue. + * + * @return new element, NEVER {@code null} + */ + T get(); + } + + interface Consumer { + /** + * This method will process an element already removed from the queue. This method is expected + * to never throw an exception. + * + *

Users should be aware that underlying queue implementations may upfront claim parts of the + * queue for batch operations and this will effect the view on the queue from the accept method. + * In particular size and any poll/peek methods may take the view that the full batch has + * already happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead + * to a broken queue. + * + * @param e not {@code null} + */ + void accept(T e); + } + + interface WaitStrategy { + /** + * This method can implement static or dynamic backoff. Dynamic backoff will rely on the counter + * for estimating how long the caller has been idling. The expected usage is: + * + *

+ * + *

+     * 
+     * int ic = 0;
+     * while(true) {
+     *   if(!isGodotArrived()) {
+     *     ic = w.idle(ic);
+     *     continue;
+     *   }
+     *   ic = 0;
+     *   // party with Godot until he goes again
+     * }
+     * 
+     * 
+ * + * @param idleCounter idle calls counter, managed by the idle method until reset + * @return new counter value to be used on subsequent idle cycle + */ + int idle(int idleCounter); + } + + interface ExitCondition { + + /** + * This method should be implemented such that the flag read or determination cannot be hoisted + * out of a loop which notmally means a volatile load, but with JDK9 VarHandles may mean + * getOpaque. + * + * @return true as long as we should keep running + */ + boolean keepRunning(); + } + + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation and + * according to the {@link Queue#offer(Object)} interface. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false iff full + */ + boolean offer(T e); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#poll()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ + T poll(); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation + * and according to the {@link Queue#peek()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ + T peek(); + + /** + * This method's accuracy is subject to concurrent modifications happening as the size is + * estimated and as such is a best effort rather than absolute value. For some implementations + * this method may be O(n) rather than O(1). + * + * @return number of messages in the queue, between 0 and {@link Integer#MAX_VALUE} but less or + * equals to capacity (if bounded). + */ + int size(); + + /** + * Removes all items from the queue. Called from the consumer thread subject to the restrictions + * appropriate to the implementation and according to the {@link Queue#clear()} interface. + */ + void clear(); + + /** + * This method's accuracy is subject to concurrent modifications happening as the observation is + * carried out. + * + * @return true if empty, false otherwise + */ + boolean isEmpty(); + + /** + * @return the capacity of this queue or {@link MessagePassingQueue#UNBOUNDED_CAPACITY} if not + * bounded + */ + int capacity(); + + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation. As + * opposed to {@link Queue#offer(Object)} this method may return false without the queue being + * full. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false if unable to offer + */ + boolean relaxedOffer(T e); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#poll()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to poll + */ + T relaxedPoll(); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. + * As opposed to {@link Queue#peek()} this method may return {@code null} without the queue being + * empty. + * + * @return a message from the queue if one is available, {@code null} if unable to peek + */ + T relaxedPeek(); + + /** + * Remove up to limit elements from the queue and hand to consume. This should be + * semantically similar to: + * + *

+ * + *

{@code
+   * M m;
+   * int i = 0;
+   * for(;i < limit && (m = relaxedPoll()) != null; i++){
+   *   c.accept(m);
+   * }
+   * return i;
+   * }
+ * + *

There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + * @throws IllegalArgumentException if limit is negative + */ + int drain(Consumer c, int limit); + + /** + * Stuff the queue with up to limit elements from the supplier. Semantically similar to: + * + *

+ * + *

{@code
+   * for(int i=0; i < limit && relaxedOffer(s.get()); i++);
+   * }
+ * + *

There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + * @throws IllegalArgumentException if limit is negative + */ + int fill(Supplier s, int limit); + + /** + * Remove all available item from the queue and hand to consume. This should be semantically + * similar to: + * + *

+   * M m;
+   * while((m = relaxedPoll()) != null){
+   * c.accept(m);
+   * }
+   * 
+ * + * There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + */ + int drain(Consumer c); + + /** + * Stuff the queue with elements from the supplier. Semantically similar to: + * + *

+   * while(relaxedOffer(s.get());
+   * 
+ * + * There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + * + *

Unbounded queues will fill up the queue with a fixed amount rather than fill up to oblivion. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + */ + int fill(Supplier s); + + /** + * Remove elements from the queue and hand to consume forever. Semantically similar to: + * + *

+ * + *

+   *  int idleCounter = 0;
+   *  while (exit.keepRunning()) {
+   *      E e = relaxedPoll();
+   *      if(e==null){
+   *          idleCounter = wait.idle(idleCounter);
+   *          continue;
+   *      }
+   *      idleCounter = 0;
+   *      c.accept(e);
+   *  }
+   * 
+ * + *

Called from a consumer thread subject to the restrictions appropriate to the implementation. + * + *

WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make + * sure you have read and understood these before using this method. + * + * @throws IllegalArgumentException c OR wait OR exit are {@code null} + */ + void drain(Consumer c, WaitStrategy wait, ExitCondition exit); + + /** + * Stuff the queue with elements from the supplier forever. Semantically similar to: + * + *

+ * + *

+   * 
+   *  int idleCounter = 0;
+   *  while (exit.keepRunning()) {
+   *      E e = s.get();
+   *      while (!relaxedOffer(e)) {
+   *          idleCounter = wait.idle(idleCounter);
+   *          continue;
+   *      }
+   *      idleCounter = 0;
+   *  }
+   * 
+   * 
+ * + *

Called from a producer thread subject to the restrictions appropriate to the implementation. + * The main difference being that implementors MUST assure room in the queue is available BEFORE + * calling {@link Supplier#get}. + * + *

WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure + * you have read and understood these before using this method. + * + * @throws IllegalArgumentException s OR wait OR exit are {@code null} + */ + void fill(Supplier s, WaitStrategy wait, ExitCondition exit); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java new file mode 100644 index 000000000..cb03364d8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MessagePassingQueueUtil.java @@ -0,0 +1,100 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.jctools.queues; + +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Consumer; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.ExitCondition; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.Supplier; +import io.rsocket.internal.jctools.queues.MessagePassingQueue.WaitStrategy; + +final class MessagePassingQueueUtil { + public static int drain(MessagePassingQueue queue, Consumer c, int limit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (limit < 0) throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) return 0; + E e; + int i = 0; + for (; i < limit && (e = queue.relaxedPoll()) != null; i++) { + c.accept(e); + } + return i; + } + + public static int drain(MessagePassingQueue queue, Consumer c) { + if (null == c) throw new IllegalArgumentException("c is null"); + E e; + int i = 0; + while ((e = queue.relaxedPoll()) != null) { + i++; + c.accept(e); + } + return i; + } + + public static void drain( + MessagePassingQueue queue, Consumer c, WaitStrategy wait, ExitCondition exit) { + if (null == c) throw new IllegalArgumentException("c is null"); + if (null == wait) throw new IllegalArgumentException("wait is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + final E e = queue.relaxedPoll(); + if (e == null) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + c.accept(e); + } + } + + public static void fill( + MessagePassingQueue q, Supplier s, WaitStrategy wait, ExitCondition exit) { + if (null == wait) throw new IllegalArgumentException("waiter is null"); + if (null == exit) throw new IllegalArgumentException("exit condition is null"); + + int idleCounter = 0; + while (exit.keepRunning()) { + if (q.fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0) { + idleCounter = wait.idle(idleCounter); + continue; + } + idleCounter = 0; + } + } + + public static int fillBounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, q.capacity()); + } + + public static int fillInBatchesToLimit( + MessagePassingQueue q, Supplier s, int batch, int limit) { + long result = + 0; // result is a long because we want to have a safepoint check at regular intervals + do { + final int filled = q.fill(s, batch); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= limit); + return (int) result; + } + + public static int fillUnbounded(MessagePassingQueue q, Supplier s) { + return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, 4096); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java new file mode 100644 index 000000000..179070be4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/MpscUnboundedArrayQueue.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.LinkedArrayQueueUtil.length; +import static io.rsocket.internal.jctools.queues.MessagePassingQueueUtil.fillUnbounded; + +/** + * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked + * chunks of the initial size. The queue grows only when the current chunk is full and elements are + * not copied on resize, instead a link to the new chunk is stored in the old chunk for the consumer + * to follow. + */ +public class MpscUnboundedArrayQueue extends BaseMpscLinkedArrayQueue { + byte b000, b001, b002, b003, b004, b005, b006, b007; // 8b + byte b010, b011, b012, b013, b014, b015, b016, b017; // 16b + byte b020, b021, b022, b023, b024, b025, b026, b027; // 24b + byte b030, b031, b032, b033, b034, b035, b036, b037; // 32b + byte b040, b041, b042, b043, b044, b045, b046, b047; // 40b + byte b050, b051, b052, b053, b054, b055, b056, b057; // 48b + byte b060, b061, b062, b063, b064, b065, b066, b067; // 56b + byte b070, b071, b072, b073, b074, b075, b076, b077; // 64b + byte b100, b101, b102, b103, b104, b105, b106, b107; // 72b + byte b110, b111, b112, b113, b114, b115, b116, b117; // 80b + byte b120, b121, b122, b123, b124, b125, b126, b127; // 88b + byte b130, b131, b132, b133, b134, b135, b136, b137; // 96b + byte b140, b141, b142, b143, b144, b145, b146, b147; // 104b + byte b150, b151, b152, b153, b154, b155, b156, b157; // 112b + byte b160, b161, b162, b163, b164, b165, b166, b167; // 120b + byte b170, b171, b172, b173, b174, b175, b176, b177; // 128b + + public MpscUnboundedArrayQueue(int chunkSize) { + super(chunkSize); + } + + @Override + protected long availableInQueue(long pIndex, long cIndex) { + return Integer.MAX_VALUE; + } + + @Override + public int capacity() { + return MessagePassingQueue.UNBOUNDED_CAPACITY; + } + + @Override + public int drain(Consumer c) { + return drain(c, 4096); + } + + @Override + public int fill(Supplier s) { + return fillUnbounded(this, s); + } + + @Override + protected int getNextBufferSize(E[] buffer) { + return length(buffer); + } + + @Override + protected long getCurrentBufferCapacity(long mask) { + return mask; + } +} diff --git a/src/main/java/io/reactivesocket/exceptions/CancelException.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java similarity index 59% rename from src/main/java/io/reactivesocket/exceptions/CancelException.java rename to rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java index 15f8ef13d..f037857e8 100644 --- a/src/main/java/io/reactivesocket/exceptions/CancelException.java +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/PortableJvmInfo.java @@ -1,6 +1,4 @@ -/** - * Copyright 2015 Netflix, Inc. - * +/* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -13,15 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; - -public class CancelException extends Throwable { - public CancelException(String message) { - super(message); - } +package io.rsocket.internal.jctools.queues; - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } +/** JVM Information that is standard and available on all JVMs (i.e. does not use unsafe) */ +interface PortableJvmInfo { + int CACHE_LINE_SIZE = Integer.getInteger("jctools.cacheLineSize", 64); + int CPUs = Runtime.getRuntime().availableProcessors(); + int RECOMENDED_OFFER_BATCH = CPUs * 4; + int RECOMENDED_POLL_BATCH = CPUs * 4; } diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java new file mode 100644 index 000000000..282a22f02 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/Pow2.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +/** Power of 2 utility functions. */ +final class Pow2 { + public static final int MAX_POW2 = 1 << 30; + + /** + * @param value from which next positive power of two will be found. + * @return the next positive power of 2, this value if it is a power of 2. Negative values are + * mapped to 1. + * @throws IllegalArgumentException is value is more than MAX_POW2 or less than 0 + */ + public static int roundToPowerOfTwo(final int value) { + if (value > MAX_POW2) { + throw new IllegalArgumentException( + "There is no larger power of 2 int for value:" + value + " since it exceeds 2^31."); + } + if (value < 0) { + throw new IllegalArgumentException("Given value:" + value + ". Expecting value >= 0."); + } + final int nextPow2 = 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + return nextPow2; + } + + /** + * @param value to be tested to see if it is a power of two. + * @return true if the value is a power of 2 otherwise false. + */ + public static boolean isPowerOfTwo(final int value) { + return (value & (value - 1)) == 0; + } + + /** + * Align a value to the next multiple up of alignment. If the value equals an alignment multiple + * then it is returned unchanged. + * + * @param value to be aligned up. + * @param alignment to be used, must be a power of 2. + * @return the value aligned to the next boundary. + */ + public static long align(final long value, final int alignment) { + if (!isPowerOfTwo(alignment)) { + throw new IllegalArgumentException("alignment must be a power of 2:" + alignment); + } + return (value + (alignment - 1)) & ~(alignment - 1); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java new file mode 100644 index 000000000..6418cc947 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/QueueProgressIndicators.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +/** + * This interface is provided for monitoring purposes only and is only available on queues where it + * is easy to provide it. The producer/consumer progress indicators usually correspond with the + * number of elements offered/polled, but they are not guaranteed to maintain that semantic. + * + * @author nitsanw + */ +public interface QueueProgressIndicators { + + /** + * This method has no concurrent visibility semantics. The value returned may be negative. Under + * normal circumstances 2 consecutive calls to this method can offer an idea of progress made by + * producer threads by subtracting the 2 results though in extreme cases (if producers have + * progressed by more than 2^64) this may also fail.
+ * This value will normally indicate number of elements passed into the queue, but may under some + * circumstances be a derivative of that figure. This method should not be used to derive size or + * emptiness. + * + * @return the current value of the producer progress index + */ + long currentProducerIndex(); + + /** + * This method has no concurrent visibility semantics. The value returned may be negative. Under + * normal circumstances 2 consecutive calls to this method can offer an idea of progress made by + * consumer threads by subtracting the 2 results though in extreme cases (if consumers have + * progressed by more than 2^64) this may also fail.
+ * This value will normally indicate number of elements taken out of the queue, but may under some + * circumstances be a derivative of that figure. This method should not be used to derive size or + * emptiness. + * + * @return the current value of the consumer progress index + */ + long currentConsumerIndex(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java new file mode 100644 index 000000000..3adcb2f3c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/RangeUtil.java @@ -0,0 +1,56 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +final class RangeUtil { + public static long checkPositive(long n, String name) { + if (n <= 0) { + throw new IllegalArgumentException(name + ": " + n + " (expected: > 0)"); + } + + return n; + } + + public static int checkPositiveOrZero(int n, String name) { + if (n < 0) { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= 0)"); + } + + return n; + } + + public static int checkLessThan(int n, int expected, String name) { + if (n >= expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: < " + expected + ')'); + } + + return n; + } + + public static int checkLessThanOrEqual(int n, long expected, String name) { + if (n > expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: <= " + expected + ')'); + } + + return n; + } + + public static int checkGreaterThanOrEqual(int n, int expected, String name) { + if (n < expected) { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= " + expected + ')'); + } + + return n; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java new file mode 100644 index 000000000..c99aeb689 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeAccess.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import sun.misc.Unsafe; + +/** + * Why should we resort to using Unsafe?
+ * + *

    + *
  1. To construct class fields which allow volatile/ordered/plain access: This requirement is + * covered by {@link AtomicReferenceFieldUpdater} and similar but their performance is + * arguably worse than the DIY approach (depending on JVM version) while Unsafe + * intrinsification is a far lesser challenge for JIT compilers. + *
  2. To construct flavors of {@link AtomicReferenceArray}. + *
  3. Other use cases exist but are not present in this library yet. + *
+ * + * @author nitsanw + */ +class UnsafeAccess { + public static final boolean SUPPORTS_GET_AND_SET_REF; + public static final boolean SUPPORTS_GET_AND_ADD_LONG; + public static final Unsafe UNSAFE; + + static { + UNSAFE = getUnsafe(); + SUPPORTS_GET_AND_SET_REF = hasGetAndSetSupport(); + SUPPORTS_GET_AND_ADD_LONG = hasGetAndAddLongSupport(); + } + + private static Unsafe getUnsafe() { + Unsafe instance; + try { + final Field field = Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + instance = (Unsafe) field.get(null); + } catch (Exception ignored) { + // Some platforms, notably Android, might not have a sun.misc.Unsafe implementation with a + // private + // `theUnsafe` static instance. In this case we can try to call the default constructor, which + // is sufficient + // for Android usage. + try { + Constructor c = Unsafe.class.getDeclaredConstructor(); + c.setAccessible(true); + instance = c.newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + return instance; + } + + private static boolean hasGetAndSetSupport() { + try { + Unsafe.class.getMethod("getAndSetObject", Object.class, Long.TYPE, Object.class); + return true; + } catch (Exception ignored) { + } + return false; + } + + private static boolean hasGetAndAddLongSupport() { + try { + Unsafe.class.getMethod("getAndAddLong", Object.class, Long.TYPE, Long.TYPE); + return true; + } catch (Exception ignored) { + } + return false; + } + + public static long fieldOffset(Class clz, String fieldName) throws RuntimeException { + try { + return UNSAFE.objectFieldOffset(clz.getDeclaredField(fieldName)); + } catch (NoSuchFieldException e) { + throw new RuntimeException(e); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java new file mode 100644 index 000000000..c734a9914 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/jctools/queues/UnsafeRefArrayAccess.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.internal.jctools.queues; + +import static io.rsocket.internal.jctools.queues.UnsafeAccess.UNSAFE; + +final class UnsafeRefArrayAccess { + public static final long REF_ARRAY_BASE; + public static final int REF_ELEMENT_SHIFT; + + static { + final int scale = UNSAFE.arrayIndexScale(Object[].class); + if (4 == scale) { + REF_ELEMENT_SHIFT = 2; + } else if (8 == scale) { + REF_ELEMENT_SHIFT = 3; + } else { + throw new IllegalStateException("Unknown pointer size: " + scale); + } + REF_ARRAY_BASE = UNSAFE.arrayBaseOffset(Object[].class); + } + + /** + * A plain store (no ordering/fences) of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @param e an orderly kitty + */ + public static void spRefElement(E[] buffer, long offset, E e) { + UNSAFE.putObject(buffer, offset, e); + } + + /** + * An ordered store of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcCircularRefElementOffset} + * @param e an orderly kitty + */ + public static void soRefElement(E[] buffer, long offset, E e) { + UNSAFE.putOrderedObject(buffer, offset, e); + } + + /** + * A plain load (no ordering/fences) of an element from a given offset. + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lpRefElement(E[] buffer, long offset) { + return (E) UNSAFE.getObject(buffer, offset); + } + + /** + * A volatile load of an element from a given offset. + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lvRefElement(E[] buffer, long offset) { + return (E) UNSAFE.getObjectVolatile(buffer, offset); + } + + /** + * @param index desirable element index + * @return the offset in bytes within the array for a given index + */ + public static long calcRefElementOffset(long index) { + return REF_ARRAY_BASE + (index << REF_ELEMENT_SHIFT); + } + + /** + * Note: circular arrays are assumed a power of 2 in length and the `mask` is (length - 1). + * + * @param index desirable element index + * @param mask (length - 1) + * @return the offset in bytes within the circular array for a given index + */ + public static long calcCircularRefElementOffset(long index, long mask) { + return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); + } + + /** This makes for an easier time generating the atomic queues, and removes some warnings. */ + @SuppressWarnings("unchecked") + public static E[] allocateRefArray(int capacity) { + return (E[]) new Object[capacity]; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/internal/package-info.java b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java new file mode 100644 index 000000000..07ddfab41 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/internal/package-info.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Internal package and must not be used outside this project. There are no guarantees for + * API compatibility. + */ +@NonNullApi +package io.rsocket.internal; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java new file mode 100644 index 000000000..8fb918dc6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveFramesAcceptor.java @@ -0,0 +1,9 @@ +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import reactor.core.Disposable; + +public interface KeepAliveFramesAcceptor extends Disposable { + + void receive(ByteBuf keepAliveFrame); +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java new file mode 100644 index 000000000..4fd7a772d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveHandler.java @@ -0,0 +1,60 @@ +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import io.rsocket.keepalive.KeepAliveSupport.KeepAlive; +import io.rsocket.resume.RSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; +import java.util.function.Consumer; + +public interface KeepAliveHandler { + + KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onFrameSent, + Consumer onTimeout); + + class DefaultKeepAliveHandler implements KeepAliveHandler { + @Override + public KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onSendKeepAliveFrame, + Consumer onTimeout) { + return keepAliveSupport + .onSendKeepAliveFrame(onSendKeepAliveFrame) + .onTimeout(onTimeout) + .start(); + } + } + + class ResumableKeepAliveHandler implements KeepAliveHandler { + + private final ResumableDuplexConnection resumableDuplexConnection; + private final RSocketSession rSocketSession; + private final ResumeStateHolder resumeStateHolder; + + public ResumableKeepAliveHandler( + ResumableDuplexConnection resumableDuplexConnection, + RSocketSession rSocketSession, + ResumeStateHolder resumeStateHolder) { + this.resumableDuplexConnection = resumableDuplexConnection; + this.rSocketSession = rSocketSession; + this.resumeStateHolder = resumeStateHolder; + } + + @Override + public KeepAliveFramesAcceptor start( + KeepAliveSupport keepAliveSupport, + Consumer onSendKeepAliveFrame, + Consumer onTimeout) { + + rSocketSession.setKeepAliveSupport(keepAliveSupport); + + return keepAliveSupport + .resumeState(resumeStateHolder) + .onSendKeepAliveFrame(onSendKeepAliveFrame) + .onTimeout(keepAlive -> resumableDuplexConnection.disconnect()) + .start(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java new file mode 100644 index 000000000..4fd18d041 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/KeepAliveSupport.java @@ -0,0 +1,201 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.keepalive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.resume.ResumeStateHolder; +import java.time.Duration; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.function.Consumer; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public abstract class KeepAliveSupport implements KeepAliveFramesAcceptor { + + final ByteBufAllocator allocator; + final Scheduler scheduler; + final Duration keepAliveInterval; + final Duration keepAliveTimeout; + final long keepAliveTimeoutMillis; + + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(KeepAliveSupport.class, "state"); + + static final int STOPPED_STATE = 0; + static final int STARTING_STATE = 1; + static final int STARTED_STATE = 2; + static final int DISPOSED_STATE = -1; + + volatile Consumer onTimeout; + volatile Consumer onFrameSent; + + Disposable ticksDisposable; + + volatile ResumeStateHolder resumeStateHolder; + volatile long lastReceivedMillis; + + private KeepAliveSupport( + ByteBufAllocator allocator, int keepAliveInterval, int keepAliveTimeout) { + this.allocator = allocator; + this.scheduler = Schedulers.parallel(); + this.keepAliveInterval = Duration.ofMillis(keepAliveInterval); + this.keepAliveTimeout = Duration.ofMillis(keepAliveTimeout); + this.keepAliveTimeoutMillis = keepAliveTimeout; + } + + public KeepAliveSupport start() { + if (this.state == STOPPED_STATE && STATE.compareAndSet(this, STOPPED_STATE, STARTING_STATE)) { + this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); + + final Disposable disposable = + Flux.interval(keepAliveInterval, scheduler).subscribe(v -> onIntervalTick()); + this.ticksDisposable = disposable; + + if (this.state != STARTING_STATE + || !STATE.compareAndSet(this, STARTING_STATE, STARTED_STATE)) { + disposable.dispose(); + } + } + return this; + } + + public void stop() { + terminate(STOPPED_STATE); + } + + @Override + public void receive(ByteBuf keepAliveFrame) { + this.lastReceivedMillis = scheduler.now(TimeUnit.MILLISECONDS); + if (resumeStateHolder != null) { + final long remoteLastReceivedPos = KeepAliveFrameCodec.lastPosition(keepAliveFrame); + resumeStateHolder.onImpliedPosition(remoteLastReceivedPos); + } + if (KeepAliveFrameCodec.respondFlag(keepAliveFrame)) { + long localLastReceivedPos = localLastReceivedPosition(); + send( + KeepAliveFrameCodec.encode( + allocator, + false, + localLastReceivedPos, + KeepAliveFrameCodec.data(keepAliveFrame).retain())); + } + } + + public KeepAliveSupport resumeState(ResumeStateHolder resumeStateHolder) { + this.resumeStateHolder = resumeStateHolder; + return this; + } + + public KeepAliveSupport onSendKeepAliveFrame(Consumer onFrameSent) { + this.onFrameSent = onFrameSent; + return this; + } + + public KeepAliveSupport onTimeout(Consumer onTimeout) { + this.onTimeout = onTimeout; + return this; + } + + @Override + public void dispose() { + terminate(DISPOSED_STATE); + } + + @Override + public boolean isDisposed() { + return ticksDisposable.isDisposed(); + } + + abstract void onIntervalTick(); + + void send(ByteBuf frame) { + if (onFrameSent != null) { + onFrameSent.accept(frame); + } + } + + void tryTimeout() { + long now = scheduler.now(TimeUnit.MILLISECONDS); + if (now - lastReceivedMillis >= keepAliveTimeoutMillis) { + if (onTimeout != null) { + onTimeout.accept(new KeepAlive(keepAliveInterval, keepAliveTimeout)); + } + stop(); + } + } + + void terminate(int terminationState) { + for (; ; ) { + final int state = this.state; + + if (state == STOPPED_STATE || state == DISPOSED_STATE) { + return; + } + + final Disposable disposable = this.ticksDisposable; + if (STATE.compareAndSet(this, state, terminationState)) { + disposable.dispose(); + return; + } + } + } + + long localLastReceivedPosition() { + return resumeStateHolder != null ? resumeStateHolder.impliedPosition() : 0; + } + + public static final class ClientKeepAliveSupport extends KeepAliveSupport { + + public ClientKeepAliveSupport( + ByteBufAllocator allocator, int keepAliveInterval, int keepAliveTimeout) { + super(allocator, keepAliveInterval, keepAliveTimeout); + } + + @Override + void onIntervalTick() { + tryTimeout(); + send( + KeepAliveFrameCodec.encode( + allocator, true, localLastReceivedPosition(), Unpooled.EMPTY_BUFFER)); + } + } + + public static final class KeepAlive { + private final Duration tickPeriod; + private final Duration timeoutMillis; + + public KeepAlive(Duration tickPeriod, Duration timeoutMillis) { + this.tickPeriod = tickPeriod; + this.timeoutMillis = timeoutMillis; + } + + public Duration getTickPeriod() { + return tickPeriod; + } + + public Duration getTimeout() { + return timeoutMillis; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java new file mode 100644 index 000000000..d94a93cad --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/keepalive/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Support classes for sending and keeping track of KEEPALIVE frames from the remote. */ +@NonNullApi +package io.rsocket.keepalive; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/lease/Lease.java b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java new file mode 100644 index 000000000..9e76d176d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/Lease.java @@ -0,0 +1,109 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.lease; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import java.time.Duration; +import reactor.util.annotation.Nullable; + +/** A contract for RSocket lease, which is sent by a request acceptor and is time bound. */ +public final class Lease { + + public static Lease create( + Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + return new Lease(timeToLive, numberOfRequests, metadata); + } + + public static Lease create(Duration timeToLive, int numberOfRequests) { + return create(timeToLive, numberOfRequests, Unpooled.EMPTY_BUFFER); + } + + public static Lease unbounded() { + return unbounded(null); + } + + public static Lease unbounded(@Nullable ByteBuf metadata) { + return create(Duration.ofMillis(Integer.MAX_VALUE), Integer.MAX_VALUE, metadata); + } + + public static Lease empty() { + return create(Duration.ZERO, 0); + } + + final int timeToLiveMillis; + final int numberOfRequests; + final ByteBuf metadata; + final long expirationTime; + + Lease(Duration timeToLive, int numberOfRequests, @Nullable ByteBuf metadata) { + this.numberOfRequests = numberOfRequests; + this.timeToLiveMillis = (int) Math.min(timeToLive.toMillis(), Integer.MAX_VALUE); + this.metadata = metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + this.expirationTime = + timeToLive.isZero() ? 0 : System.currentTimeMillis() + timeToLive.toMillis(); + } + + /** + * Number of requests allowed by this lease. + * + * @return The number of requests allowed by this lease. + */ + public int numberOfRequests() { + return numberOfRequests; + } + + /** + * Time to live for the given lease + * + * @return relative duration in milliseconds + */ + public int timeToLiveInMillis() { + return this.timeToLiveMillis; + } + + /** + * Absolute time since epoch at which this lease will expire. + * + * @return Absolute time since epoch at which this lease will expire. + */ + public long expirationTime() { + return expirationTime; + } + + /** + * Metadata for the lease. + * + * @return Metadata for the lease. + */ + @Nullable + public ByteBuf metadata() { + return metadata; + } + + @Override + public String toString() { + return "Lease{" + + "timeToLiveMillis=" + + timeToLiveMillis + + ", numberOfRequests=" + + numberOfRequests + + ", expirationTime=" + + expirationTime + + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java new file mode 100644 index 000000000..48bd38494 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/LeaseSender.java @@ -0,0 +1,8 @@ +package io.rsocket.lease; + +import reactor.core.publisher.Flux; + +public interface LeaseSender { + + Flux send(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java new file mode 100644 index 000000000..84af91b1b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/MissingLeaseException.java @@ -0,0 +1,31 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.lease; + +import io.rsocket.exceptions.RejectedException; + +public class MissingLeaseException extends RejectedException { + private static final long serialVersionUID = -6169748673403858959L; + + public MissingLeaseException(String message) { + super(message); + } + + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java new file mode 100644 index 000000000..3e6f68321 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/TrackingLeaseSender.java @@ -0,0 +1,5 @@ +package io.rsocket.lease; + +import io.rsocket.plugins.RequestInterceptor; + +public interface TrackingLeaseSender extends LeaseSender, RequestInterceptor {} diff --git a/rsocket-core/src/main/java/io/rsocket/lease/package-info.java b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java new file mode 100644 index 000000000..342ab27f7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/lease/package-info.java @@ -0,0 +1,27 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains support classes for the Lease feature of the RSocket protocol. + * + * @see Resuming + * Operation + */ +@NonNullApi +package io.rsocket.lease; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java new file mode 100644 index 000000000..fdbbeb25d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/BaseWeightedStats.java @@ -0,0 +1,235 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Implementation of {@link WeightedStats} that manages tracking state and exposes the required + * stats. + * + *

A sub-class or a different class (delegation) needs to call {@link #startStream()}, {@link + * #stopStream()}, {@link #startRequest()}, and {@link #stopRequest(long)} to drive state tracking. + * + * @since 1.1 + * @see WeightedStatsRequestInterceptor + */ +public class BaseWeightedStats implements WeightedStats { + + private static final double DEFAULT_LOWER_QUANTILE = 0.5; + private static final double DEFAULT_HIGHER_QUANTILE = 0.8; + private static final int INACTIVITY_FACTOR = 500; + private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = + Clock.unit().convert(1L, TimeUnit.SECONDS); + + private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; + + private final Quantile lowerQuantile; + private final Quantile higherQuantile; + private final Ewma availabilityPercentage; + private final Median median; + private final Ewma interArrivalTime; + + private final long tau; + private final long inactivityFactor; + + private long errorStamp; // last we got an error + private long stamp; // last timestamp we sent a request + private long stamp0; // last timestamp we sent a request or receive a response + private long duration; // instantaneous cumulative duration + + private volatile int pendingRequests; // instantaneous rate + private static final AtomicIntegerFieldUpdater PENDING_REQUESTS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingRequests"); + private volatile int pendingStreams; // number of active streams + private static final AtomicIntegerFieldUpdater PENDING_STREAMS = + AtomicIntegerFieldUpdater.newUpdater(BaseWeightedStats.class, "pendingStreams"); + + protected BaseWeightedStats() { + this( + new FrugalQuantile(DEFAULT_LOWER_QUANTILE), + new FrugalQuantile(DEFAULT_HIGHER_QUANTILE), + INACTIVITY_FACTOR); + } + + private BaseWeightedStats( + Quantile lowerQuantile, Quantile higherQuantile, long inactivityFactor) { + this.lowerQuantile = lowerQuantile; + this.higherQuantile = higherQuantile; + this.inactivityFactor = inactivityFactor; + + long now = Clock.now(); + this.stamp = now; + this.errorStamp = now; + this.stamp0 = now; + this.duration = 0L; + this.pendingRequests = 0; + this.median = new Median(); + this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); + this.availabilityPercentage = new Ewma(5, TimeUnit.SECONDS, 1.0); + this.tau = Clock.unit().convert((long) (5 / Math.log(2)), TimeUnit.SECONDS); + } + + @Override + public double lowerQuantileLatency() { + return lowerQuantile.estimation(); + } + + @Override + public double higherQuantileLatency() { + return higherQuantile.estimation(); + } + + @Override + public int pending() { + return pendingRequests + pendingStreams; + } + + @Override + public double weightedAvailability() { + if (Clock.now() - stamp > tau) { + updateAvailability(1.0); + } + return availabilityPercentage.value(); + } + + @Override + public double predictedLatency() { + final long now = Clock.now(); + final long elapsed; + + synchronized (this) { + elapsed = Math.max(now - stamp, 1L); + } + + final double latency; + final double prediction = median.estimation(); + + final int pending = this.pending(); + if (prediction == 0.0) { + if (pending == 0) { + latency = 0.0; // first request + } else { + // subsequent requests while we don't have any history + latency = STARTUP_PENALTY + pending; + } + } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { + // if we did't see any data for a while, we decay the prediction by inserting + // artificial 0.0 into the median + median.insert(0.0); + latency = median.estimation(); + } else { + final double predicted = prediction * pending; + final double instant = instantaneous(now, pending); + + if (predicted < instant) { // NB: (0.0 < 0.0) == false + latency = instant / pending; // NB: pending never equal 0 here + } else { + // we are under the predictions + latency = prediction; + } + } + + return latency; + } + + long instantaneous(long now, int pending) { + return duration + (now - stamp0) * pending; + } + + void startStream() { + PENDING_STREAMS.incrementAndGet(this); + } + + void stopStream() { + PENDING_STREAMS.decrementAndGet(this); + } + + synchronized long startRequest() { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + interArrivalTime.insert(now - stamp); + duration += Math.max(0, now - stamp0) * pendingRequests; + PENDING_REQUESTS.lazySet(this, pendingRequests + 1); + stamp = now; + stamp0 = now; + + return now; + } + + synchronized long stopRequest(long timestamp) { + final long now = Clock.now(); + final int pendingRequests = this.pendingRequests; + + duration += Math.max(0, now - stamp0) * pendingRequests - (now - timestamp); + PENDING_REQUESTS.lazySet(this, pendingRequests - 1); + stamp0 = now; + + return now; + } + + synchronized void record(double roundTripTime) { + median.insert(roundTripTime); + lowerQuantile.insert(roundTripTime); + higherQuantile.insert(roundTripTime); + } + + void updateAvailability(double value) { + availabilityPercentage.insert(value); + if (value == 0.0d) { + synchronized (this) { + errorStamp = Clock.now(); + } + } + } + + @Override + public String toString() { + return "Stats{" + + "lowerQuantile=" + + lowerQuantile.estimation() + + ", higherQuantile=" + + higherQuantile.estimation() + + ", inactivityFactor=" + + inactivityFactor + + ", tau=" + + tau + + ", errorPercentage=" + + availabilityPercentage.value() + + ", pending=" + + pendingRequests + + ", errorStamp=" + + errorStamp + + ", stamp=" + + stamp + + ", stamp0=" + + stamp0 + + ", duration=" + + duration + + ", median=" + + median.estimation() + + ", interArrivalTime=" + + interArrivalTime.value() + + ", pendingStreams=" + + pendingStreams + + ", availability=" + + availabilityPercentage.value() + + '}'; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java new file mode 100644 index 000000000..528f4f896 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ClientLoadbalanceStrategy.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.InterceptorRegistry; + +/** + * A {@link LoadbalanceStrategy} with an interest in configuring the {@link RSocketConnector} for + * connecting to load-balance targets in order to hook into request lifecycle and track usage + * statistics. + * + *

Currently this callback interface is supported for strategies configured in {@link + * LoadbalanceRSocketClient}. + * + * @since 1.1 + */ +public interface ClientLoadbalanceStrategy extends LoadbalanceStrategy { + + /** + * Initialize the connector, for example using the {@link InterceptorRegistry}, to intercept + * requests. + * + * @param connector the connector to configure + */ + void initialize(RSocketConnector connector); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java new file mode 100644 index 000000000..0f87f6510 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Ewma.java @@ -0,0 +1,71 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +/** + * Compute the exponential weighted moving average of a series of values. The time at which you + * insert the value into `Ewma` is used to compute a weight (recent points are weighted higher). The + * parameter for defining the convergence speed (like most decay process) is the half-life. + * + *

e.g. with a half-life of 10 unit, if you insert 100 at t=0 and 200 at t=10 the ewma will be + * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) + */ +class Ewma { + + final long tau; + + volatile long stamp; + static final AtomicLongFieldUpdater STAMP = + AtomicLongFieldUpdater.newUpdater(Ewma.class, "stamp"); + volatile double ewma; + + public Ewma(long halfLife, TimeUnit unit, double initialValue) { + this.tau = Clock.unit().convert((long) (halfLife / Math.log(2)), unit); + + this.ewma = initialValue; + + STAMP.lazySet(this, 0L); + } + + public synchronized void insert(double x) { + final long now = Clock.now(); + final double elapsed = Math.max(0, now - stamp); + + STAMP.lazySet(this, now); + + double w = Math.exp(-elapsed / tau); + ewma = w * ewma + (1.0 - w) * x; + } + + public synchronized void reset(double value) { + stamp = 0L; + ewma = value; + } + + public double value() { + return ewma; + } + + @Override + public String toString() { + return "Ewma(value=" + ewma + ", age=" + (Clock.now() - stamp) + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java new file mode 100644 index 000000000..6c2b9c3ea --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FluxDeferredResolution.java @@ -0,0 +1,228 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +abstract class FluxDeferredResolution extends Flux + implements CoreSubscriber, Subscription, BiConsumer, Scannable { + + final ResolvingOperator parent; + final INPUT fluxOrPayload; + final FrameType requestType; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(FluxDeferredResolution.class, "requested"); + + static final long STATE_UNSUBSCRIBED = -1; + static final long STATE_SUBSCRIBER_SET = 0; + static final long STATE_SUBSCRIBED = -2; + static final long STATE_TERMINATED = Long.MIN_VALUE; + + Subscription s; + CoreSubscriber actual; + boolean done; + + FluxDeferredResolution(ResolvingOperator parent, INPUT fluxOrPayload, FrameType requestType) { + this.parent = parent; + this.fluxOrPayload = fluxOrPayload; + this.requestType = requestType; + + REQUESTED.lazySet(this, STATE_UNSUBSCRIBED); + } + + @Override + public final void subscribe(CoreSubscriber actual) { + if (this.requested == STATE_UNSUBSCRIBED + && REQUESTED.compareAndSet(this, STATE_UNSUBSCRIBED, STATE_SUBSCRIBER_SET)) { + + actual.onSubscribe(this); + + if (this.requested == STATE_TERMINATED) { + return; + } + + this.actual = actual; + this.parent.observe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public final Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_TERMINATED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_TERMINATED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_TERMINATED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(Payload payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public final void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_TERMINATED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public final void cancel() { + long state = REQUESTED.getAndSet(this, STATE_TERMINATED); + if (state == STATE_TERMINATED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + if (requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + } + } + + boolean isTerminated() { + return this.requested == STATE_TERMINATED; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java new file mode 100644 index 000000000..cdbdc19b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/FrugalQuantile.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import java.util.SplittableRandom; + +/** + * Reference: Ma, Qiang, S. Muthukrishnan, and Mark Sandler. "Frugal Streaming for Estimating + * Quantiles." Space-Efficient Data Structures, Streams, and Algorithms. Springer Berlin Heidelberg, + * 2013. 77-96. + * + *

More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ + */ +class FrugalQuantile implements Quantile { + final double increment; + final SplittableRandom rnd; + + int step; + int sign; + double quantile; + + volatile double estimate; + + public FrugalQuantile(double quantile, double increment) { + this.increment = increment; + this.quantile = quantile; + this.estimate = 0.0; + this.step = 1; + this.sign = 0; + this.rnd = new SplittableRandom(System.nanoTime()); + } + + public FrugalQuantile(double quantile) { + this(quantile, 1.0); + } + + public synchronized void reset(double quantile) { + this.quantile = quantile; + this.estimate = 0.0; + this.step = 1; + this.sign = 0; + } + + public double estimation() { + return estimate; + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + } else { + final double v = rnd.nextDouble(); + final double estimate = this.estimate; + + if (x > estimate && v > (1 - quantile)) { + higher(x); + } else if (x < estimate && v > quantile) { + lower(x); + } + } + } + + private void higher(double x) { + double estimate = this.estimate; + + step += sign * increment; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + + this.estimate = estimate; + } + + private void lower(double x) { + double estimate = this.estimate; + + step -= sign * increment; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + + this.estimate = estimate; + } + + @Override + public String toString() { + return "FrugalQuantile(q=" + quantile + ", v=" + estimate + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java new file mode 100644 index 000000000..eebf82fe9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Int2LongHashMap.java @@ -0,0 +1,1005 @@ +/* + * Copyright 2014-2020 Real Logic Limited. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.io.Serializable; +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.IntToLongFunction; +import reactor.util.annotation.Nullable; + +/** A open addressing with linear probing hash map specialised for primitive key and value pairs. */ +class Int2LongHashMap implements Map, Serializable { + static final float DEFAULT_LOAD_FACTOR = 0.55f; + static final int MIN_CAPACITY = 8; + private static final long serialVersionUID = -690554872053575793L; + + private final float loadFactor; + private final long missingValue; + private int resizeThreshold; + private int size = 0; + private final boolean shouldAvoidAllocation; + + private long[] entries; + private KeySet keySet; + private ValueCollection values; + private EntrySet entrySet; + + /** @param missingValue for the map that represents null. */ + public Int2LongHashMap(final long missingValue) { + this(MIN_CAPACITY, DEFAULT_LOAD_FACTOR, missingValue); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + */ + public Int2LongHashMap( + final int initialCapacity, final float loadFactor, final long missingValue) { + this(initialCapacity, loadFactor, missingValue, true); + } + + /** + * @param initialCapacity for the map to override {@link #MIN_CAPACITY} + * @param loadFactor for the map to override {@link #DEFAULT_LOAD_FACTOR}. + * @param missingValue for the map that represents null. + * @param shouldAvoidAllocation should allocation be avoided by caching iterators and map entries. + */ + public Int2LongHashMap( + final int initialCapacity, + final float loadFactor, + final long missingValue, + final boolean shouldAvoidAllocation) { + validateLoadFactor(loadFactor); + + this.loadFactor = loadFactor; + this.missingValue = missingValue; + this.shouldAvoidAllocation = shouldAvoidAllocation; + + capacity(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, initialCapacity))); + } + + /** + * The value to be used as a null marker in the map. + * + * @return value to be used as a null marker in the map. + */ + public long missingValue() { + return missingValue; + } + + /** + * Get the load factor applied for resize operations. + * + * @return the load factor applied for resize operations. + */ + public float loadFactor() { + return loadFactor; + } + + /** + * Get the total capacity for the map to which the load factor will be a fraction of. + * + * @return the total capacity for the map. + */ + public int capacity() { + return entries.length >> 1; + } + + /** + * Get the actual threshold which when reached the map will resize. This is a function of the + * current capacity and load factor. + * + * @return the threshold when the map will resize. + */ + public int resizeThreshold() { + return resizeThreshold; + } + + /** {@inheritDoc} */ + public int size() { + return size; + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return size == 0; + } + + /** + * Get a value using provided key avoiding boxing. + * + * @param key lookup key. + * @return value associated with the key or {@link #missingValue()} if key is not found in the + * map. + */ + public long get(final int key) { + final int mask = entries.length - 1; + int index = evenHash(key, mask); + + long value = missingValue; + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + value = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + return value; + } + + /** + * Put a key value pair in the map. + * + * @param key lookup key + * @param value new value, must not be {@link #missingValue()} + * @return previous value associated with the key, or {@link #missingValue()} if none found + * @throws IllegalArgumentException if value is {@link #missingValue()} + */ + public long put(final int key, final long value) { + if (value == missingValue) { + throw new IllegalArgumentException("cannot accept missingValue"); + } + + final int mask = entries.length - 1; + int index = evenHash(key, mask); + long oldValue = missingValue; + + while (entries[index + 1] != missingValue) { + if (entries[index] == key) { + oldValue = entries[index + 1]; + break; + } + + index = next(index, mask); + } + + if (oldValue == missingValue) { + ++size; + entries[index] = key; + } + + entries[index + 1] = value; + + increaseCapacity(); + + return oldValue; + } + + private void increaseCapacity() { + if (size > resizeThreshold) { + // entries.length = 2 * capacity + final int newCapacity = entries.length; + rehash(newCapacity); + } + } + + private void rehash(final int newCapacity) { + final long[] oldEntries = entries; + final int length = entries.length; + + capacity(newCapacity); + + final long[] newEntries = entries; + final int mask = entries.length - 1; + + for (int keyIndex = 0; keyIndex < length; keyIndex += 2) { + final long value = oldEntries[keyIndex + 1]; + if (value != missingValue) { + final int key = (int) oldEntries[keyIndex]; + int index = evenHash(key, mask); + + while (newEntries[index + 1] != missingValue) { + index = next(index, mask); + } + + newEntries[index] = key; + newEntries[index + 1] = value; + } + } + } + + /** + * Int primitive specialised containsKey. + * + * @param key the key to check. + * @return true if the map contains key as a key, false otherwise. + */ + public boolean containsKey(final int key) { + return get(key) != missingValue; + } + + /** + * Does the map contain the value. + * + * @param value to be tested against contained values. + * @return true if contained otherwise value. + */ + public boolean containsValue(final long value) { + boolean found = false; + if (value != missingValue) { + final int length = entries.length; + int remaining = size; + + for (int valueIndex = 1; remaining > 0 && valueIndex < length; valueIndex += 2) { + if (missingValue != entries[valueIndex]) { + if (value == entries[valueIndex]) { + found = true; + break; + } + --remaining; + } + } + } + + return found; + } + + /** {@inheritDoc} */ + public void clear() { + if (size > 0) { + Arrays.fill(entries, missingValue); + size = 0; + } + } + + /** + * Compact the backing arrays by rehashing with a capacity just larger than current size and + * giving consideration to the load factor. + */ + public void compact() { + final int idealCapacity = (int) Math.round(size() * (1.0d / loadFactor)); + rehash(findNextPositivePowerOfTwo(Math.max(MIN_CAPACITY, idealCapacity))); + } + + /** + * Primitive specialised version of {@link #computeIfAbsent(Object, Function)} + * + * @param key to search on. + * @param mappingFunction to provide a value if the get returns null. + * @return the value if found otherwise the missing value. + */ + public long computeIfAbsent(final int key, final IntToLongFunction mappingFunction) { + long value = get(key); + if (value == missingValue) { + value = mappingFunction.applyAsLong(key); + if (value != missingValue) { + put(key, value); + } + } + + return value; + } + + // ---------------- Boxed Versions Below ---------------- + + /** {@inheritDoc} */ + @Nullable + public Long get(final Object key) { + return valOrNull(get((int) key)); + } + + /** {@inheritDoc} */ + public Long put(final Integer key, final Long value) { + return valOrNull(put((int) key, (long) value)); + } + + /** {@inheritDoc} */ + public boolean containsKey(final Object key) { + return containsKey((int) key); + } + + /** {@inheritDoc} */ + public boolean containsValue(final Object value) { + return containsValue((long) value); + } + + /** {@inheritDoc} */ + public void putAll(final Map map) { + for (final Map.Entry entry : map.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + /** {@inheritDoc} */ + public KeySet keySet() { + if (null == keySet) { + keySet = new KeySet(); + } + + return keySet; + } + + /** {@inheritDoc} */ + public ValueCollection values() { + if (null == values) { + values = new ValueCollection(); + } + + return values; + } + + /** {@inheritDoc} */ + public EntrySet entrySet() { + if (null == entrySet) { + entrySet = new EntrySet(); + } + + return entrySet; + } + + /** {@inheritDoc} */ + @Nullable + public Long remove(final Object key) { + return valOrNull(remove((int) key)); + } + + /** + * Remove value from the map using given key avoiding boxing. + * + * @param key whose mapping is to be removed from the map. + * @return removed value or {@link #missingValue()} if key was not found in the map. + */ + public long remove(final int key) { + final int mask = entries.length - 1; + int keyIndex = evenHash(key, mask); + + long oldValue = missingValue; + while (entries[keyIndex + 1] != missingValue) { + if (entries[keyIndex] == key) { + oldValue = entries[keyIndex + 1]; + entries[keyIndex + 1] = missingValue; + size--; + + compactChain(keyIndex); + + break; + } + + keyIndex = next(keyIndex, mask); + } + + return oldValue; + } + + @SuppressWarnings("FinalParameters") + private void compactChain(int deleteKeyIndex) { + final int mask = entries.length - 1; + int keyIndex = deleteKeyIndex; + + while (true) { + keyIndex = next(keyIndex, mask); + if (entries[keyIndex + 1] == missingValue) { + break; + } + + final int hash = evenHash((int) entries[keyIndex], mask); + + if ((keyIndex < hash && (hash <= deleteKeyIndex || deleteKeyIndex <= keyIndex)) + || (hash <= deleteKeyIndex && deleteKeyIndex <= keyIndex)) { + entries[deleteKeyIndex] = entries[keyIndex]; + entries[deleteKeyIndex + 1] = entries[keyIndex + 1]; + + entries[keyIndex + 1] = missingValue; + deleteKeyIndex = keyIndex; + } + } + } + + /** + * Get the minimum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the minimum value stored in the map. + */ + public long minValue() { + final long missingValue = this.missingValue; + long min = size == 0 ? missingValue : Long.MAX_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + min = Math.min(min, value); + } + } + + return min; + } + + /** + * Get the maximum value stored in the map. If the map is empty then it will return {@link + * #missingValue()} + * + * @return the maximum value stored in the map. + */ + public long maxValue() { + final long missingValue = this.missingValue; + long max = size == 0 ? missingValue : Long.MIN_VALUE; + final int length = entries.length; + + for (int valueIndex = 1; valueIndex < length; valueIndex += 2) { + final long value = entries[valueIndex]; + if (value != missingValue) { + max = Math.max(max, value); + } + } + + return max; + } + + /** {@inheritDoc} */ + public String toString() { + if (isEmpty()) { + return "{}"; + } + + final EntryIterator entryIterator = new EntryIterator(); + entryIterator.reset(); + + final StringBuilder sb = new StringBuilder().append('{'); + while (true) { + entryIterator.next(); + sb.append(entryIterator.getIntKey()).append('=').append(entryIterator.getLongValue()); + if (!entryIterator.hasNext()) { + return sb.append('}').toString(); + } + sb.append(',').append(' '); + } + } + + /** + * Primitive specialised version of {@link #replace(Object, Object)} + * + * @param key key with which the specified value is associated + * @param value value to be associated with the specified key + * @return the previous value associated with the specified key, or {@link #missingValue()} if + * there was no mapping for the key. + */ + public long replace(final int key, final long value) { + long currentValue = get(key); + if (currentValue != missingValue) { + currentValue = put(key, value); + } + + return currentValue; + } + + /** + * Primitive specialised version of {@link #replace(Object, Object, Object)} + * + * @param key key with which the specified value is associated + * @param oldValue value expected to be associated with the specified key + * @param newValue value to be associated with the specified key + * @return {@code true} if the value was replaced + */ + public boolean replace(final int key, final long oldValue, final long newValue) { + final long curValue = get(key); + if (curValue != oldValue || curValue == missingValue) { + return false; + } + + put(key, newValue); + + return true; + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Map)) { + return false; + } + + final Map that = (Map) o; + + return size == that.size() && entrySet().equals(that.entrySet()); + } + + public int hashCode() { + return entrySet().hashCode(); + } + + private static int next(final int index, final int mask) { + return (index + 2) & mask; + } + + private void capacity(final int newCapacity) { + final int entriesLength = newCapacity * 2; + if (entriesLength < 0) { + throw new IllegalStateException("max capacity reached at size=" + size); + } + + /*@DoNotSub*/ resizeThreshold = (int) (newCapacity * loadFactor); + entries = new long[entriesLength]; + Arrays.fill(entries, missingValue); + } + + @Nullable + private Long valOrNull(final long value) { + return value == missingValue ? null : value; + } + + // ---------------- Utility Classes ---------------- + + /** Base iterator implementation. */ + abstract class AbstractIterator implements Serializable { + private static final long serialVersionUID = 5262459454112462433L; + /** Is current position valid. */ + protected boolean isPositionValid = false; + + private int remaining; + private int positionCounter; + private int stopCounter; + + final void reset() { + isPositionValid = false; + remaining = Int2LongHashMap.this.size; + final long missingValue = Int2LongHashMap.this.missingValue; + final long[] entries = Int2LongHashMap.this.entries; + final int capacity = entries.length; + + int keyIndex = capacity; + if (entries[capacity - 1] != missingValue) { + for (int i = 1; i < capacity; i += 2) { + if (entries[i] == missingValue) { + keyIndex = i - 1; + break; + } + } + } + + stopCounter = keyIndex; + positionCounter = keyIndex + capacity; + } + + /** + * Returns position of the key of the current entry. + * + * @return key position. + */ + protected final int keyPosition() { + return positionCounter & entries.length - 1; + } + + /** + * Number of remaining elements. + * + * @return number of remaining elements. + */ + public int remaining() { + return remaining; + } + + /** + * Check if there are more elements remaining. + * + * @return {@code true} if {@code remaining > 0}. + */ + public boolean hasNext() { + return remaining > 0; + } + + /** + * Advance to the next entry. + * + * @throws NoSuchElementException if no more entries available. + */ + protected final void findNext() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + final long[] entries = Int2LongHashMap.this.entries; + final long missingValue = Int2LongHashMap.this.missingValue; + final int mask = entries.length - 1; + + for (int keyIndex = positionCounter - 2; keyIndex >= stopCounter; keyIndex -= 2) { + final int index = keyIndex & mask; + if (entries[index + 1] != missingValue) { + isPositionValid = true; + positionCounter = keyIndex; + --remaining; + return; + } + } + + isPositionValid = false; + throw new IllegalStateException(); + } + + /** {@inheritDoc} */ + public void remove() { + if (isPositionValid) { + final int position = keyPosition(); + entries[position + 1] = missingValue; + --size; + + compactChain(position); + + isPositionValid = false; + } else { + throw new IllegalStateException(); + } + } + } + + /** Iterator over keys which supports access to unboxed keys via {@link #nextValue()}. */ + public final class KeyIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = 9151493609653852972L; + + public Integer next() { + return nextValue(); + } + + /** + * Return next key. + * + * @return next key. + */ + public int nextValue() { + findNext(); + return (int) entries[keyPosition()]; + } + } + + /** Iterator over values which supports access to unboxed values. */ + public final class ValueIterator extends AbstractIterator + implements Iterator, Serializable { + private static final long serialVersionUID = -5670291734793552927L; + + public Long next() { + return nextValue(); + } + + /** + * Return next value. + * + * @return next value. + */ + public long nextValue() { + findNext(); + return entries[keyPosition() + 1]; + } + } + + /** Iterator over entries which supports access to unboxed keys and values. */ + public final class EntryIterator extends AbstractIterator + implements Iterator>, Entry, Serializable { + private static final long serialVersionUID = 1744408438593481051L; + + public Integer getKey() { + return getIntKey(); + } + + /** + * Returns the key of the current entry. + * + * @return the key. + */ + public int getIntKey() { + return (int) entries[keyPosition()]; + } + + public Long getValue() { + return getLongValue(); + } + + /** + * Returns the value of the current entry. + * + * @return the value. + */ + public long getLongValue() { + return entries[keyPosition() + 1]; + } + + public Long setValue(final Long value) { + return setValue(value.longValue()); + } + + /** + * Sets the value of the current entry. + * + * @param value to be set. + * @return previous value of the entry. + */ + public long setValue(final long value) { + if (!isPositionValid) { + throw new IllegalStateException(); + } + + if (missingValue == value) { + throw new IllegalArgumentException(); + } + + final int keyPosition = keyPosition(); + final long prevValue = entries[keyPosition + 1]; + entries[keyPosition + 1] = value; + return prevValue; + } + + public Entry next() { + findNext(); + + if (shouldAvoidAllocation) { + return this; + } + + return allocateDuplicateEntry(); + } + + private Entry allocateDuplicateEntry() { + return new MapEntry(getIntKey(), getLongValue()); + } + + /** {@inheritDoc} */ + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + /** {@inheritDoc} */ + public boolean equals(final Object o) { + if (this == o) { + return true; + } + + if (!(o instanceof Entry)) { + return false; + } + + final Entry that = (Entry) o; + + return Objects.equals(getKey(), that.getKey()) && Objects.equals(getValue(), that.getValue()); + } + + /** An {@link java.util.Map.Entry} implementation. */ + public final class MapEntry implements Entry { + private final int k; + private final long v; + + /** + * Constructs entry with given key and value. + * + * @param k key. + * @param v value. + */ + public MapEntry(final int k, final long v) { + this.k = k; + this.v = v; + } + + public Integer getKey() { + return k; + } + + public Long getValue() { + return v; + } + + public Long setValue(final Long value) { + return Int2LongHashMap.this.put(k, value.longValue()); + } + + public int hashCode() { + return Integer.hashCode(getIntKey()) ^ Long.hashCode(getLongValue()); + } + + public boolean equals(final Object o) { + if (!(o instanceof Map.Entry)) { + return false; + } + + final Entry e = (Entry) o; + + return (e.getKey() != null && e.getValue() != null) + && (e.getKey().equals(k) && e.getValue().equals(v)); + } + + public String toString() { + return k + "=" + v; + } + } + } + + /** Set of keys which supports optional cached iterators to avoid allocation. */ + public final class KeySet extends AbstractSet implements Serializable { + private static final long serialVersionUID = -7645453993079742625L; + private final KeyIterator keyIterator = shouldAvoidAllocation ? new KeyIterator() : null; + + /** {@inheritDoc} */ + public KeyIterator iterator() { + KeyIterator keyIterator = this.keyIterator; + if (null == keyIterator) { + keyIterator = new KeyIterator(); + } + + keyIterator.reset(); + + return keyIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((int) o); + } + + /** + * Checks if key is contained in the map without boxing. + * + * @param key to check. + * @return {@code true} if key is contained in this map. + */ + public boolean contains(final int key) { + return containsKey(key); + } + } + + /** Collection of values which supports optionally cached iterators to avoid allocation. */ + public final class ValueCollection extends AbstractCollection implements Serializable { + private static final long serialVersionUID = -8925598924781601919L; + private final ValueIterator valueIterator = shouldAvoidAllocation ? new ValueIterator() : null; + + /** {@inheritDoc} */ + public ValueIterator iterator() { + ValueIterator valueIterator = this.valueIterator; + if (null == valueIterator) { + valueIterator = new ValueIterator(); + } + + valueIterator.reset(); + + return valueIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + return contains((long) o); + } + + /** + * Checks if the value is contained in the map. + * + * @param value to be checked. + * @return {@code true} if value is contained in this map. + */ + public boolean contains(final long value) { + return containsValue(value); + } + } + + /** Set of entries which supports optionally cached iterators to avoid allocation. */ + public final class EntrySet extends AbstractSet> + implements Serializable { + private static final long serialVersionUID = 63641283589916174L; + private final EntryIterator entryIterator = shouldAvoidAllocation ? new EntryIterator() : null; + + /** {@inheritDoc} */ + public EntryIterator iterator() { + EntryIterator entryIterator = this.entryIterator; + if (null == entryIterator) { + entryIterator = new EntryIterator(); + } + + entryIterator.reset(); + + return entryIterator; + } + + /** {@inheritDoc} */ + public int size() { + return Int2LongHashMap.this.size(); + } + + /** {@inheritDoc} */ + public boolean isEmpty() { + return Int2LongHashMap.this.isEmpty(); + } + + /** {@inheritDoc} */ + public void clear() { + Int2LongHashMap.this.clear(); + } + + /** {@inheritDoc} */ + public boolean contains(final Object o) { + if (!(o instanceof Entry)) { + return false; + } + final Entry entry = (Entry) o; + final Long value = get(entry.getKey()); + + return value != null && value.equals(entry.getValue()); + } + + /** {@inheritDoc} */ + public Object[] toArray() { + return toArray(new Object[size()]); + } + + /** {@inheritDoc} */ + @SuppressWarnings("unchecked") + public T[] toArray(final T[] a) { + final T[] array = + a.length >= size + ? a + : (T[]) java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), size); + final EntryIterator it = iterator(); + + for (int i = 0; i < array.length; i++) { + if (it.hasNext()) { + it.next(); + array[i] = (T) it.allocateDuplicateEntry(); + } else { + array[i] = null; + break; + } + } + + return array; + } + } + + private static int evenHash(final int value, final int mask) { + final int hash = (value << 1) - (value << 8); + + return hash & mask; + } + + private static void validateLoadFactor(final float loadFactor) { + if (loadFactor < 0.1f || loadFactor > 0.9f) { + throw new IllegalArgumentException( + "load factor must be in the range of 0.1 to 0.9: " + loadFactor); + } + } + + private static int findNextPositivePowerOfTwo(final int value) { + return 1 << (Integer.SIZE - Integer.numberOfLeadingZeros(value - 1)); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java new file mode 100644 index 000000000..d59cbb86e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java @@ -0,0 +1,195 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import java.util.List; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * An implementation of {@link RSocketClient} backed by a pool of {@code RSocket} instances and + * using a {@link LoadbalanceStrategy} to select the {@code RSocket} to use for a given request. + * + * @since 1.1 + */ +public class LoadbalanceRSocketClient implements RSocketClient { + + private final RSocketPool rSocketPool; + + private LoadbalanceRSocketClient(RSocketPool rSocketPool) { + this.rSocketPool = rSocketPool; + } + + @Override + public Mono onClose() { + return rSocketPool.onClose(); + } + + @Override + public boolean connect() { + return rSocketPool.connect(); + } + + /** Return {@code Mono} that selects an RSocket from the underlying pool. */ + @Override + public Mono source() { + return Mono.fromSupplier(rSocketPool::select); + } + + @Override + public Mono fireAndForget(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().fireAndForget(p)); + } + + @Override + public Mono requestResponse(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().requestResponse(p)); + } + + @Override + public Flux requestStream(Mono payloadMono) { + return payloadMono.flatMapMany(p -> rSocketPool.select().requestStream(p)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return source().flatMapMany(rSocket -> rSocket.requestChannel(payloads)); + } + + @Override + public Mono metadataPush(Mono payloadMono) { + return payloadMono.flatMap(p -> rSocketPool.select().metadataPush(p)); + } + + @Override + public void dispose() { + rSocketPool.dispose(); + } + + /** + * Shortcut to create an {@link LoadbalanceRSocketClient} with round-robin load balancing. + * Effectively a shortcut for: + * + *

+   * LoadbalanceRSocketClient.builder(targetPublisher)
+   *    .connector(RSocketConnector.create())
+   *    .build();
+   * 
+ * + * @param connector a "template" for connecting to load balance targets + * @param targetPublisher refreshes the list of load balance targets periodically + * @return the created client instance + */ + public static LoadbalanceRSocketClient create( + RSocketConnector connector, Publisher> targetPublisher) { + return builder(targetPublisher).connector(connector).build(); + } + + /** + * Return a builder for a {@link LoadbalanceRSocketClient}. + * + * @param targetPublisher refreshes the list of load balance targets periodically + * @return the created builder + */ + public static Builder builder(Publisher> targetPublisher) { + return new Builder(targetPublisher); + } + + /** Builder for creating an {@link LoadbalanceRSocketClient}. */ + public static class Builder { + + private final Publisher> targetPublisher; + + @Nullable private RSocketConnector connector; + + @Nullable LoadbalanceStrategy loadbalanceStrategy; + + Builder(Publisher> targetPublisher) { + this.targetPublisher = targetPublisher; + } + + /** + * Configure the "template" connector to use for connecting to load balance targets. To + * establish a connection, the {@link LoadbalanceTarget#getTransport() ClientTransport} + * contained in each target is passed to the connector's {@link + * RSocketConnector#connect(ClientTransport) connect} method and thus the same connector with + * the same settings applies to all targets. + * + *

By default this is initialized with {@link RSocketConnector#create()}. + * + * @param connector the connector to use as a template + */ + public Builder connector(RSocketConnector connector) { + this.connector = connector; + return this; + } + + /** + * Configure {@link RoundRobinLoadbalanceStrategy} as the strategy to use to select targets. + * + *

This is the strategy used by default. + */ + public Builder roundRobinLoadbalanceStrategy() { + this.loadbalanceStrategy = new RoundRobinLoadbalanceStrategy(); + return this; + } + + /** + * Configure {@link WeightedLoadbalanceStrategy} as the strategy to use to select targets. + * + *

By default, {@link RoundRobinLoadbalanceStrategy} is used. + */ + public Builder weightedLoadbalanceStrategy() { + this.loadbalanceStrategy = WeightedLoadbalanceStrategy.create(); + return this; + } + + /** + * Configure the {@link LoadbalanceStrategy} to use. + * + *

By default, {@link RoundRobinLoadbalanceStrategy} is used. + */ + public Builder loadbalanceStrategy(LoadbalanceStrategy strategy) { + this.loadbalanceStrategy = strategy; + return this; + } + + /** Build the {@link LoadbalanceRSocketClient} instance. */ + public LoadbalanceRSocketClient build() { + final RSocketConnector connector = + (this.connector != null ? this.connector : RSocketConnector.create()); + + final LoadbalanceStrategy strategy = + (this.loadbalanceStrategy != null + ? this.loadbalanceStrategy + : new RoundRobinLoadbalanceStrategy()); + + if (strategy instanceof ClientLoadbalanceStrategy) { + ((ClientLoadbalanceStrategy) strategy).initialize(connector); + } + + return new LoadbalanceRSocketClient( + new RSocketPool(connector, this.targetPublisher, strategy)); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java new file mode 100644 index 000000000..5662448e7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceStrategy.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import java.util.List; + +/** + * Strategy to select an {@link RSocket} given a list of instances for load-balancing purposes. A + * simple implementation might go in round-robin fashion while a more sophisticated strategy might + * check availability, track usage stats, and so on. + * + * @since 1.1 + */ +@FunctionalInterface +public interface LoadbalanceStrategy { + + /** + * Select an {@link RSocket} from the given non-empty list. + * + * @param sockets the list to choose from + * @return the selected instance + */ + RSocket select(List sockets); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java new file mode 100644 index 000000000..3b5d71e4e --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceTarget.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import org.reactivestreams.Publisher; + +/** + * Representation for a load-balance target used as input to {@link LoadbalanceRSocketClient} that + * in turn maintains and peridodically updates a list of current load-balance targets. The {@link + * #getKey()} is used to identify a target uniquely while the {@link #getTransport() transport} is + * used to connect to the target server. + * + * @since 1.1 + * @see LoadbalanceRSocketClient#create(RSocketConnector, Publisher) + */ +public class LoadbalanceTarget { + + final String key; + final ClientTransport transport; + + private LoadbalanceTarget(String key, ClientTransport transport) { + this.key = key; + this.transport = transport; + } + + /** Return the key that identifies this target uniquely. */ + public String getKey() { + return key; + } + + /** Return the transport to use to connect to the target server. */ + public ClientTransport getTransport() { + return transport; + } + + /** + * Create a new {@link LoadbalanceTarget} with the given key and {@link ClientTransport}. The key + * can be anything that identifies the target uniquely, e.g. SocketAddress, URL, and so on. + * + * @param key identifies the load-balance target uniquely + * @param transport for connecting to the target + * @return the created instance + */ + public static LoadbalanceTarget from(String key, ClientTransport transport) { + return new LoadbalanceTarget(key, transport); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } + if (other == null || getClass() != other.getClass()) { + return false; + } + LoadbalanceTarget that = (LoadbalanceTarget) other; + return key.equals(that.key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java new file mode 100644 index 000000000..5319706f9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Median.java @@ -0,0 +1,99 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +/** This implementation gives better results because it considers more data-point. */ +class Median extends FrugalQuantile { + + public Median() { + super(0.5, 1.0); + } + + public synchronized void reset() { + super.reset(0.5); + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + } else { + final double estimate = this.estimate; + if (x > estimate) { + greaterThanZero(x); + } else if (x < estimate) { + lessThanZero(x); + } + } + } + + private void greaterThanZero(double x) { + double estimate = this.estimate; + + step += sign; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + + this.estimate = estimate; + } + + private void lessThanZero(double x) { + double estimate = this.estimate; + + step -= sign; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + + this.estimate = estimate; + } + + @Override + public String toString() { + return "Median(v=" + estimate + ")"; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java new file mode 100644 index 000000000..69838f1b6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/MonoDeferredResolution.java @@ -0,0 +1,226 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.function.BiConsumer; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +abstract class MonoDeferredResolution extends Mono + implements CoreSubscriber, Subscription, Scannable, BiConsumer { + + final ResolvingOperator parent; + final Payload payload; + final FrameType requestType; + + volatile long requested; + + @SuppressWarnings("rawtypes") + static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(MonoDeferredResolution.class, "requested"); + + static final long STATE_UNSUBSCRIBED = -1; + static final long STATE_SUBSCRIBER_SET = 0; + static final long STATE_SUBSCRIBED = -2; + static final long STATE_TERMINATED = Long.MIN_VALUE; + + Subscription s; + CoreSubscriber actual; + boolean done; + + MonoDeferredResolution(ResolvingOperator parent, Payload payload, FrameType requestType) { + this.parent = parent; + this.payload = payload; + this.requestType = requestType; + + REQUESTED.lazySet(this, STATE_UNSUBSCRIBED); + } + + @Override + public final void subscribe(CoreSubscriber actual) { + if (this.requested == STATE_UNSUBSCRIBED + && REQUESTED.compareAndSet(this, STATE_UNSUBSCRIBED, STATE_SUBSCRIBER_SET)) { + + actual.onSubscribe(this); + + if (this.requested == STATE_TERMINATED) { + return; + } + + this.actual = actual; + this.parent.observe(this); + } else { + Operators.error(actual, new IllegalStateException("Only a single Subscriber allowed")); + } + } + + @Override + public final Context currentContext() { + return this.actual.currentContext(); + } + + @Nullable + @Override + public Object scanUnsafe(Attr key) { + long state = this.requested; + + if (key == Attr.PARENT) { + return this.s; + } + if (key == Attr.ACTUAL) { + return this.parent; + } + if (key == Attr.TERMINATED) { + return this.done; + } + if (key == Attr.CANCELLED) { + return state == STATE_TERMINATED; + } + + return null; + } + + @Override + public final void onSubscribe(Subscription s) { + final long state = this.requested; + Subscription a = this.s; + if (state == STATE_TERMINATED) { + s.cancel(); + return; + } + if (a != null) { + s.cancel(); + return; + } + + long r; + long accumulated = 0; + for (; ; ) { + r = this.requested; + + if (r == STATE_TERMINATED || r == STATE_SUBSCRIBED) { + s.cancel(); + return; + } + + this.s = s; + + long toRequest = r - accumulated; + if (toRequest > 0) { // if there is something, + s.request(toRequest); // then we do a request on the given subscription + } + accumulated = r; + + if (REQUESTED.compareAndSet(this, r, STATE_SUBSCRIBED)) { + return; + } + } + } + + @Override + public final void onNext(RESULT payload) { + this.actual.onNext(payload); + } + + @Override + public final void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + this.done = true; + this.actual.onError(t); + } + + @Override + public final void onComplete() { + if (this.done) { + return; + } + + this.done = true; + this.actual.onComplete(); + } + + @Override + public final void request(long n) { + if (Operators.validate(n)) { + long r = this.requested; // volatile read beforehand + + if (r > STATE_SUBSCRIBED) { // works only in case onSubscribe has not happened + long u; + for (; ; ) { // normal CAS loop with overflow protection + if (r == Long.MAX_VALUE) { + // if r == Long.MAX_VALUE then we dont care and we can loose this + // request just in case of racing + return; + } + u = Operators.addCap(r, n); + if (REQUESTED.compareAndSet(this, r, u)) { + // Means increment happened before onSubscribe + return; + } else { + // Means increment happened after onSubscribe + + // update new state to see what exactly happened (onSubscribe |cancel | requestN) + r = this.requested; + + // check state (expect -1 | -2 to exit, otherwise repeat) + if (r < 0) { + break; + } + } + } + } + + if (r == STATE_TERMINATED) { // if canceled, just exit + return; + } + + // if onSubscribe -> subscription exists (and we sure of that because volatile read + // after volatile write) so we can execute requestN on the subscription + this.s.request(n); + } + } + + public final void cancel() { + long state = REQUESTED.getAndSet(this, STATE_TERMINATED); + if (state == STATE_TERMINATED) { + return; + } + + if (state == STATE_SUBSCRIBED) { + this.s.cancel(); + } else { + this.parent.remove(this); + ReferenceCountUtil.safeRelease(this.payload); + } + } + + boolean isTerminated() { + return this.requested == STATE_TERMINATED; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java new file mode 100644 index 000000000..a77329d31 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/PooledRSocket.java @@ -0,0 +1,310 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.context.Context; + +/** Default implementation of {@link RSocket} stored in {@link RSocketPool} */ +final class PooledRSocket extends ResolvingOperator + implements CoreSubscriber, RSocket { + + final RSocketPool parent; + final Mono rSocketSource; + final LoadbalanceTarget loadbalanceTarget; + final Sinks.Empty onCloseSink; + + volatile Subscription s; + + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(PooledRSocket.class, Subscription.class, "s"); + + PooledRSocket( + RSocketPool parent, Mono rSocketSource, LoadbalanceTarget loadbalanceTarget) { + this.parent = parent; + this.rSocketSource = rSocketSource; + this.loadbalanceTarget = loadbalanceTarget; + this.onCloseSink = Sinks.unsafe().empty(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onComplete() { + final Subscription s = this.s; + final RSocket value = this.value; + + if (s == Operators.cancelledSubscription() || !S.compareAndSet(this, s, null)) { + this.doFinally(); + return; + } + + if (value == null) { + this.terminate(new IllegalStateException("Source completed empty")); + } else { + this.complete(value); + } + } + + @Override + public void onError(Throwable t) { + final Subscription s = this.s; + + if (s == Operators.cancelledSubscription() + || S.getAndSet(this, Operators.cancelledSubscription()) + == Operators.cancelledSubscription()) { + this.doFinally(); + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doFinally(); + // terminate upstream (retryBackoff has exhausted) and remove from the parent target list + this.doCleanup(t); + } + + @Override + public void onNext(RSocket value) { + if (this.s == Operators.cancelledSubscription()) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + // volatile write and check on racing + this.doFinally(); + } + + @Override + protected void doSubscribe() { + this.rSocketSource.subscribe(this); + } + + @Override + protected void doOnValueResolved(RSocket value) { + value.onClose().subscribe(null, this::doCleanup, () -> doCleanup(ON_DISPOSE)); + } + + void doCleanup(Throwable t) { + if (isDisposed()) { + return; + } + + this.terminate(t); + + final RSocketPool parent = this.parent; + for (; ; ) { + final PooledRSocket[] sockets = parent.activeSockets; + final int activeSocketsCount = sockets.length; + + int index = -1; + for (int i = 0; i < activeSocketsCount; i++) { + if (sockets[i] == this) { + index = i; + break; + } + } + + if (index == -1) { + break; + } + + final PooledRSocket[] newSockets; + if (activeSocketsCount == 1) { + newSockets = RSocketPool.EMPTY; + } else { + final int lastIndex = activeSocketsCount - 1; + + newSockets = new PooledRSocket[lastIndex]; + if (index != 0) { + System.arraycopy(sockets, 0, newSockets, 0, index); + } + + if (index != lastIndex) { + System.arraycopy(sockets, index + 1, newSockets, index, lastIndex - index); + } + } + + if (RSocketPool.ACTIVE_SOCKETS.compareAndSet(parent, sockets, newSockets)) { + break; + } + } + + if (t == ON_DISPOSE) { + this.onCloseSink.tryEmitEmpty(); + } else { + this.onCloseSink.tryEmitError(t); + } + } + + @Override + protected void doOnValueExpired(RSocket value) { + value.dispose(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + + final RSocket value = this.value; + if (value != null) { + value.onClose().subscribe(null, onCloseSink::tryEmitError, onCloseSink::tryEmitEmpty); + } else { + onCloseSink.tryEmitEmpty(); + } + } + + @Override + public Mono fireAndForget(Payload payload) { + return new MonoInner<>(this, payload, FrameType.REQUEST_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return new MonoInner<>(this, payload, FrameType.REQUEST_RESPONSE); + } + + @Override + public Flux requestStream(Payload payload) { + return new FluxInner<>(this, payload, FrameType.REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new FluxInner<>(this, payloads, FrameType.REQUEST_CHANNEL); + } + + @Override + public Mono metadataPush(Payload payload) { + return new MonoInner<>(this, payload, FrameType.METADATA_PUSH); + } + + LoadbalanceTarget target() { + return this.loadbalanceTarget; + } + + @Override + public Mono onClose() { + return this.onCloseSink.asMono(); + } + + @Override + public double availability() { + final RSocket socket = valueIfResolved(); + return socket != null ? socket.availability() : 0.0d; + } + + static final class MonoInner extends MonoDeferredResolution { + + MonoInner(PooledRSocket parent, Payload payload, FrameType requestType) { + super(parent, payload, requestType); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(RSocket rSocket, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + ReferenceCountUtil.safeRelease(this.payload); + onError(t); + return; + } + + if (rSocket != null) { + Mono source; + switch (this.requestType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(this.payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(this.payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(this.payload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + } else { + parent.observe(this); + } + } + } + + static final class FluxInner extends FluxDeferredResolution { + + FluxInner(PooledRSocket parent, INPUT fluxOrPayload, FrameType requestType) { + super(parent, fluxOrPayload, requestType); + } + + @Override + @SuppressWarnings("unchecked") + public void accept(RSocket rSocket, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(t); + return; + } + + if (rSocket != null) { + Flux source; + switch (this.requestType) { + case REQUEST_STREAM: + source = rSocket.requestStream((Payload) this.fluxOrPayload); + break; + case REQUEST_CHANNEL: + source = rSocket.requestChannel((Flux) this.fluxOrPayload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + } else { + parent.observe(this); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java new file mode 100644 index 000000000..84c699197 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/Quantile.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +interface Quantile { + /** @return the estimation of the current value of the quantile */ + double estimation(); + + /** + * Insert a data point `x` in the quantile estimator. + * + * @param x the data point to add. + */ + void insert(double x); +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java new file mode 100644 index 000000000..59d9678d0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RSocketPool.java @@ -0,0 +1,532 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.frame.FrameType; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.ListIterator; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.stream.Collectors; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +class RSocketPool extends ResolvingOperator + implements CoreSubscriber>, Closeable { + + static final AtomicReferenceFieldUpdater ACTIVE_SOCKETS = + AtomicReferenceFieldUpdater.newUpdater( + RSocketPool.class, PooledRSocket[].class, "activeSockets"); + static final PooledRSocket[] EMPTY = new PooledRSocket[0]; + static final PooledRSocket[] TERMINATED = new PooledRSocket[0]; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(RSocketPool.class, Subscription.class, "s"); + final DeferredResolutionRSocket deferredResolutionRSocket = new DeferredResolutionRSocket(this); + final RSocketConnector connector; + final LoadbalanceStrategy loadbalanceStrategy; + final Sinks.Empty onAllClosedSink = Sinks.unsafe().empty(); + volatile PooledRSocket[] activeSockets; + volatile Subscription s; + + public RSocketPool( + RSocketConnector connector, + Publisher> targetPublisher, + LoadbalanceStrategy loadbalanceStrategy) { + this.connector = connector; + this.loadbalanceStrategy = loadbalanceStrategy; + + ACTIVE_SOCKETS.lazySet(this, EMPTY); + + targetPublisher.subscribe(this); + } + + @Override + public Mono onClose() { + return onAllClosedSink.asMono(); + } + + @Override + protected void doOnDispose() { + Operators.terminate(S, this); + + RSocket[] activeSockets = ACTIVE_SOCKETS.getAndSet(this, TERMINATED); + for (RSocket rSocket : activeSockets) { + rSocket.dispose(); + } + + if (activeSockets.length > 0) { + Mono.whenDelayError( + Arrays.stream(activeSockets).map(RSocket::onClose).collect(Collectors.toList())) + .subscribe(null, onAllClosedSink::tryEmitError, onAllClosedSink::tryEmitEmpty); + } else { + onAllClosedSink.tryEmitEmpty(); + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(List targets) { + if (isDisposed()) { + return; + } + + // This operation should happen less frequently than calls to select() (which are per request) + // and therefore it is acceptable somewhat less efficient. + + PooledRSocket[] previouslyActiveSockets; + PooledRSocket[] inactiveSockets; + PooledRSocket[] socketsToUse; + for (; ; ) { + HashMap rSocketSuppliersCopy = new HashMap<>(targets.size()); + + int j = 0; + for (LoadbalanceTarget target : targets) { + rSocketSuppliersCopy.put(target, j++); + } + + // Intersect current and new list of targets and find the ones to keep vs dispose + previouslyActiveSockets = this.activeSockets; + inactiveSockets = new PooledRSocket[previouslyActiveSockets.length]; + PooledRSocket[] nextActiveSockets = + new PooledRSocket[previouslyActiveSockets.length + rSocketSuppliersCopy.size()]; + int activeSocketsPosition = 0; + int inactiveSocketsPosition = 0; + for (int i = 0; i < previouslyActiveSockets.length; i++) { + PooledRSocket rSocket = previouslyActiveSockets[i]; + + Integer index = rSocketSuppliersCopy.remove(rSocket.target()); + if (index == null) { + // if one of the active rSockets is not included, we remove it and put in the + // pending removal + if (!rSocket.isDisposed()) { + inactiveSockets[inactiveSocketsPosition++] = rSocket; + // TODO: provide a meaningful algo for keeping removed rsocket in the list + // nextActiveSockets[position++] = rSocket; + } + } else { + if (!rSocket.isDisposed()) { + // keep old RSocket instance + nextActiveSockets[activeSocketsPosition++] = rSocket; + } else { + // put newly create RSocket instance + LoadbalanceTarget target = targets.get(index); + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); + } + } + } + + // The remainder are the brand new targets + for (LoadbalanceTarget target : rSocketSuppliersCopy.keySet()) { + nextActiveSockets[activeSocketsPosition++] = + new PooledRSocket(this, this.connector.connect(target.getTransport()), target); + } + + if (activeSocketsPosition == 0) { + socketsToUse = EMPTY; + } else { + socketsToUse = Arrays.copyOf(nextActiveSockets, activeSocketsPosition); + } + if (ACTIVE_SOCKETS.compareAndSet(this, previouslyActiveSockets, socketsToUse)) { + break; + } + } + + for (PooledRSocket inactiveSocket : inactiveSockets) { + if (inactiveSocket == null) { + break; + } + + inactiveSocket.dispose(); + } + + if (isPending()) { + // notifies that upstream is resolved + if (socketsToUse != EMPTY) { + //noinspection ConstantConditions + complete(this); + } + } + } + + @Override + public void onError(Throwable t) { + // indicates upstream termination + S.set(this, Operators.cancelledSubscription()); + // propagates error and terminates the whole pool + terminate(t); + } + + @Override + public void onComplete() { + // indicates upstream termination + S.set(this, Operators.cancelledSubscription()); + } + + RSocket select() { + if (isDisposed()) { + return this.deferredResolutionRSocket; + } + + RSocket selected = doSelect(); + + if (selected == null) { + if (this.s == Operators.cancelledSubscription()) { + terminate(new CancellationException("Pool is exhausted")); + } else { + invalidate(); + + // check since it is possible that between doSelect() and invalidate() we might + // have received new sockets + selected = doSelect(); + if (selected != null) { + return selected; + } + } + return this.deferredResolutionRSocket; + } + + return selected; + } + + @Nullable + RSocket doSelect() { + PooledRSocket[] sockets = this.activeSockets; + + if (sockets == EMPTY || sockets == TERMINATED) { + return null; + } + + return this.loadbalanceStrategy.select(WrappingList.wrap(sockets)); + } + + static class DeferredResolutionRSocket implements RSocket { + + final RSocketPool parent; + + DeferredResolutionRSocket(RSocketPool parent) { + this.parent = parent; + } + + @Override + public Mono fireAndForget(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.REQUEST_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.REQUEST_RESPONSE); + } + + @Override + public Flux requestStream(Payload payload) { + return new FluxInner<>(this.parent, payload, FrameType.REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return new FluxInner<>(this.parent, payloads, FrameType.REQUEST_CHANNEL); + } + + @Override + public Mono metadataPush(Payload payload) { + return new MonoInner<>(this.parent, payload, FrameType.METADATA_PUSH); + } + } + + static final class MonoInner extends MonoDeferredResolution { + + MonoInner(RSocketPool parent, Payload payload, FrameType requestType) { + super(parent, payload, requestType); + } + + @Override + @SuppressWarnings({"unchecked", "rawtypes"}) + public void accept(Object aVoid, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + ReferenceCountUtil.safeRelease(this.payload); + onError(t); + return; + } + + RSocketPool parent = (RSocketPool) this.parent; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Mono source; + switch (this.requestType) { + case REQUEST_FNF: + source = rSocket.fireAndForget(this.payload); + break; + case REQUEST_RESPONSE: + source = rSocket.requestResponse(this.payload); + break; + case METADATA_PUSH: + source = rSocket.metadataPush(this.payload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe((CoreSubscriber) this); + + return; + } + + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + ReferenceCountUtil.safeRelease(this.payload); + onError(error); + return; + } + } + } + } + + static final class FluxInner extends FluxDeferredResolution { + + FluxInner(RSocketPool parent, INPUT fluxOrPayload, FrameType requestType) { + super(parent, fluxOrPayload, requestType); + } + + @Override + @SuppressWarnings("unchecked") + public void accept(Object aVoid, Throwable t) { + if (isTerminated()) { + return; + } + + if (t != null) { + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(t); + return; + } + + RSocketPool parent = (RSocketPool) this.parent; + for (; ; ) { + RSocket rSocket = parent.doSelect(); + if (rSocket != null) { + Flux source; + switch (this.requestType) { + case REQUEST_STREAM: + source = rSocket.requestStream((Payload) this.fluxOrPayload); + break; + case REQUEST_CHANNEL: + source = rSocket.requestChannel((Flux) this.fluxOrPayload); + break; + default: + Operators.error(this.actual, new IllegalStateException("Should never happen")); + return; + } + + source.subscribe(this); + + return; + } + + final int state = parent.add(this); + + if (state == ADDED_STATE) { + return; + } + + if (state == TERMINATED_STATE) { + final Throwable error = parent.t; + if (this.requestType == FrameType.REQUEST_STREAM) { + ReferenceCountUtil.safeRelease(this.fluxOrPayload); + } + onError(error); + return; + } + } + } + } + + static final class WrappingList implements List { + + static final ThreadLocal INSTANCE = ThreadLocal.withInitial(WrappingList::new); + + private PooledRSocket[] activeSockets; + + static List wrap(PooledRSocket[] activeSockets) { + final WrappingList sockets = INSTANCE.get(); + sockets.activeSockets = activeSockets; + return sockets; + } + + @Override + public RSocket get(int index) { + final PooledRSocket socket = activeSockets[index]; + + RSocket realValue = socket.value; + if (realValue != null) { + return realValue; + } + + realValue = socket.valueIfResolved(); + if (realValue != null) { + return realValue; + } + + return socket; + } + + @Override + public int size() { + return activeSockets.length; + } + + @Override + public boolean isEmpty() { + return activeSockets.length == 0; + } + + @Override + public Object[] toArray() { + return activeSockets; + } + + @Override + @SuppressWarnings("unchecked") + public T[] toArray(T[] a) { + return (T[]) activeSockets; + } + + @Override + public boolean contains(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean add(RSocket weightedRSocket) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean containsAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean addAll(int index, Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean removeAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean retainAll(Collection c) { + throw new UnsupportedOperationException(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket set(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public void add(int index, RSocket element) { + throw new UnsupportedOperationException(); + } + + @Override + public RSocket remove(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public int indexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public int lastIndexOf(Object o) { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator() { + throw new UnsupportedOperationException(); + } + + @Override + public ListIterator listIterator(int index) { + throw new UnsupportedOperationException(); + } + + @Override + public List subList(int fromIndex, int toIndex) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java new file mode 100644 index 000000000..52f16e166 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/ResolvingOperator.java @@ -0,0 +1,420 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import java.time.Duration; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BiConsumer; +import reactor.core.Disposable; +import reactor.core.Exceptions; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +// This class is a copy of the same class in io.rsocket.core + +class ResolvingOperator implements Disposable { + + static final CancellationException ON_DISPOSE = new CancellationException("Disposed"); + + volatile int wip; + + @SuppressWarnings("rawtypes") + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ResolvingOperator.class, "wip"); + + volatile BiConsumer[] subscribers; + + @SuppressWarnings("rawtypes") + static final AtomicReferenceFieldUpdater SUBSCRIBERS = + AtomicReferenceFieldUpdater.newUpdater( + ResolvingOperator.class, BiConsumer[].class, "subscribers"); + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_UNSUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] EMPTY_SUBSCRIBED = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] READY = new BiConsumer[0]; + + @SuppressWarnings("unchecked") + static final BiConsumer[] TERMINATED = new BiConsumer[0]; + + static final int ADDED_STATE = 0; + static final int READY_STATE = 1; + static final int TERMINATED_STATE = 2; + + T value; + Throwable t; + + public ResolvingOperator() { + + SUBSCRIBERS.lazySet(this, EMPTY_UNSUBSCRIBED); + } + + @Override + public final void dispose() { + this.terminate(ON_DISPOSE); + } + + @Override + public final boolean isDisposed() { + return this.subscribers == TERMINATED; + } + + public final boolean isPending() { + BiConsumer[] state = this.subscribers; + return state != READY && state != TERMINATED; + } + + @Nullable + public final T valueIfResolved() { + if (this.subscribers == READY) { + T value = this.value; + if (value != null) { + return value; + } + } + + return null; + } + + final void observe(BiConsumer actual) { + for (; ; ) { + final int state = this.add(actual); + + T value = this.value; + + if (state == READY_STATE) { + if (value != null) { + actual.accept(value, null); + return; + } + // value == null means racing between invalidate and this subscriber + // thus, we have to loop again + continue; + } else if (state == TERMINATED_STATE) { + actual.accept(null, this.t); + return; + } + + return; + } + } + + /** + * Block the calling thread for the specified time, waiting for the completion of this {@code + * ReconnectMono}. If the {@link ResolvingOperator} is completed with an error a RuntimeException + * that wraps the error is thrown. + * + * @param timeout the timeout value as a {@link Duration} + * @return the value of this {@link ResolvingOperator} or {@code null} if the timeout is reached + * and the {@link ResolvingOperator} has not completed + * @throws RuntimeException if terminated with error + * @throws IllegalStateException if timed out or {@link Thread} was interrupted with {@link + * InterruptedException} + */ + @Nullable + @SuppressWarnings({"uncheked", "BusyWait"}) + public T block(@Nullable Duration timeout) { + try { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + + // connect once + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + long delay; + if (null == timeout) { + delay = 0L; + } else { + delay = System.nanoTime() + timeout.toNanos(); + } + for (; ; ) { + subscribers = this.subscribers; + + if (subscribers == READY) { + final T value = this.value; + if (value != null) { + return value; + } else { + // value == null means racing between invalidate and this block + // thus, we have to update the state again and see what happened + subscribers = this.subscribers; + } + } + if (subscribers == TERMINATED) { + RuntimeException re = Exceptions.propagate(this.t); + re = Exceptions.addSuppressed(re, new Exception("Terminated with an error")); + throw re; + } + if (timeout != null && delay < System.nanoTime()) { + throw new IllegalStateException("Timeout on Mono blocking read"); + } + + // connect again since invalidate() has happened in between + if (subscribers == EMPTY_UNSUBSCRIBED + && SUBSCRIBERS.compareAndSet(this, EMPTY_UNSUBSCRIBED, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + } + + Thread.sleep(1); + } + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + + throw new IllegalStateException("Thread Interruption on Mono blocking read"); + } + } + + @SuppressWarnings("unchecked") + final void terminate(Throwable t) { + if (isDisposed()) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + // writes happens before volatile write + this.t = t; + + final BiConsumer[] subscribers = SUBSCRIBERS.getAndSet(this, TERMINATED); + if (subscribers == TERMINATED) { + Operators.onErrorDropped(t, Context.empty()); + return; + } + + this.doOnDispose(); + + this.doFinally(); + + for (BiConsumer consumer : subscribers) { + consumer.accept(null, t); + } + } + + final void complete(T value) { + BiConsumer[] subscribers = this.subscribers; + if (subscribers == TERMINATED) { + this.doOnValueExpired(value); + return; + } + + this.value = value; + + for (; ; ) { + // ensures TERMINATE is going to be replaced with READY + if (SUBSCRIBERS.compareAndSet(this, subscribers, READY)) { + break; + } + + subscribers = this.subscribers; + + if (subscribers == TERMINATED) { + this.doFinally(); + return; + } + } + + this.doOnValueResolved(value); + + for (BiConsumer consumer : subscribers) { + consumer.accept(value, null); + } + } + + protected void doOnValueResolved(T value) { + // no ops + } + + final void doFinally() { + if (WIP.getAndIncrement(this) != 0) { + return; + } + + int m = 1; + T value; + + for (; ; ) { + value = this.value; + if (value != null && isDisposed()) { + this.value = null; + this.doOnValueExpired(value); + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + return; + } + } + } + + final void invalidate() { + if (this.subscribers == TERMINATED) { + return; + } + + final BiConsumer[] subscribers = this.subscribers; + + if (subscribers == READY) { + // guarded section to ensure we expire value exactly once if there is racing + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final T value = this.value; + if (value != null) { + this.value = null; + this.doOnValueExpired(value); + } + + int m = 1; + for (; ; ) { + if (isDisposed()) { + return; + } + + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + } + + SUBSCRIBERS.compareAndSet(this, READY, EMPTY_UNSUBSCRIBED); + } + } + + protected void doOnValueExpired(T value) { + // no ops + } + + protected void doOnDispose() { + // no ops + } + + public final boolean connect() { + for (; ; ) { + final BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return false; + } + + if (a == READY) { + return true; + } + + if (a != EMPTY_UNSUBSCRIBED) { + // do nothing if already started + return true; + } + + if (SUBSCRIBERS.compareAndSet(this, a, EMPTY_SUBSCRIBED)) { + this.doSubscribe(); + return true; + } + } + } + + final int add(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + + if (a == TERMINATED) { + return TERMINATED_STATE; + } + + if (a == READY) { + return READY_STATE; + } + + int n = a.length; + @SuppressWarnings("unchecked") + BiConsumer[] b = new BiConsumer[n + 1]; + System.arraycopy(a, 0, b, 0, n); + b[n] = ps; + + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + if (a == EMPTY_UNSUBSCRIBED) { + this.doSubscribe(); + } + return ADDED_STATE; + } + } + } + + protected void doSubscribe() { + // no ops + } + + @SuppressWarnings("unchecked") + final void remove(BiConsumer ps) { + for (; ; ) { + BiConsumer[] a = this.subscribers; + int n = a.length; + if (n == 0) { + return; + } + + int j = -1; + for (int i = 0; i < n; i++) { + if (a[i] == ps) { + j = i; + break; + } + } + + if (j < 0) { + return; + } + + BiConsumer[] b; + + if (n == 1) { + b = EMPTY_SUBSCRIBED; + } else { + b = new BiConsumer[n - 1]; + System.arraycopy(a, 0, b, 0, j); + System.arraycopy(a, j + 1, b, j, n - j - 1); + } + if (SUBSCRIBERS.compareAndSet(this, a, b)) { + return; + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java new file mode 100644 index 000000000..f1a9f8c55 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategy.java @@ -0,0 +1,42 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import java.util.List; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +/** + * Simple {@link LoadbalanceStrategy} that selects the {@code RSocket} to use in round-robin order. + * + * @since 1.1 + */ +public class RoundRobinLoadbalanceStrategy implements LoadbalanceStrategy { + + volatile int nextIndex; + + private static final AtomicIntegerFieldUpdater NEXT_INDEX = + AtomicIntegerFieldUpdater.newUpdater(RoundRobinLoadbalanceStrategy.class, "nextIndex"); + + @Override + public RSocket select(List sockets) { + int length = sockets.size(); + + int indexToUse = Math.abs(NEXT_INDEX.getAndIncrement(this) % length); + + return sockets.get(indexToUse); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java new file mode 100644 index 000000000..c30c8ad6b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategy.java @@ -0,0 +1,249 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.plugins.RequestInterceptor; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Function; +import reactor.util.annotation.Nullable; + +/** + * {@link LoadbalanceStrategy} that assigns a weight to each {@code RSocket} based on {@link + * RSocket#availability() availability} and usage statistics. The weight is used to decide which + * {@code RSocket} to select. + * + *

Use {@link #create()} or a {@link #builder() Builder} to create an instance. + * + * @since 1.1 + * @see Predictive Load-Balancing: Unfair but + * Faster & more Robust + * @see WeightedStatsRequestInterceptor + */ +public class WeightedLoadbalanceStrategy implements ClientLoadbalanceStrategy { + + private static final double EXP_FACTOR = 4.0; + + final int maxPairSelectionAttempts; + final Function weightedStatsResolver; + + private WeightedLoadbalanceStrategy( + int numberOfAttempts, @Nullable Function resolver) { + this.maxPairSelectionAttempts = numberOfAttempts; + this.weightedStatsResolver = (resolver != null ? resolver : new DefaultWeightedStatsResolver()); + } + + @Override + public void initialize(RSocketConnector connector) { + final Function resolver = weightedStatsResolver; + if (resolver instanceof DefaultWeightedStatsResolver) { + ((DefaultWeightedStatsResolver) resolver).init(connector); + } + } + + @Override + public RSocket select(List sockets) { + final int size = sockets.size(); + + RSocket weightedRSocket; + final Function weightedStatsResolver = this.weightedStatsResolver; + switch (size) { + case 1: + weightedRSocket = sockets.get(0); + break; + case 2: + { + RSocket rsc1 = sockets.get(0); + RSocket rsc2 = sockets.get(1); + + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } + break; + default: + { + RSocket rsc1 = null; + RSocket rsc2 = null; + + for (int i = 0; i < this.maxPairSelectionAttempts; i++) { + int i1 = ThreadLocalRandom.current().nextInt(size); + int i2 = ThreadLocalRandom.current().nextInt(size - 1); + + if (i2 >= i1) { + i2++; + } + rsc1 = sockets.get(i1); + rsc2 = sockets.get(i2); + if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { + break; + } + } + + if (rsc1 != null & rsc2 != null) { + double w1 = algorithmicWeight(rsc1, weightedStatsResolver.apply(rsc1)); + double w2 = algorithmicWeight(rsc2, weightedStatsResolver.apply(rsc2)); + + if (w1 < w2) { + weightedRSocket = rsc2; + } else { + weightedRSocket = rsc1; + } + } else if (rsc1 != null) { + weightedRSocket = rsc1; + } else { + weightedRSocket = rsc2; + } + } + } + + return weightedRSocket; + } + + private static double algorithmicWeight( + RSocket rSocket, @Nullable final WeightedStats weightedStats) { + if (weightedStats == null) { + return 1.0; + } + if (rSocket.isDisposed() || rSocket.availability() == 0.0) { + return 0.0; + } + final int pending = weightedStats.pending(); + + double latency = weightedStats.predictedLatency(); + + final double low = weightedStats.lowerQuantileLatency(); + final double high = + Math.max( + weightedStats.higherQuantileLatency(), + low * 1.001); // ensure higherQuantile > lowerQuantile + .1% + final double bandWidth = Math.max(high - low, 1); + + if (latency < low) { + latency /= calculateFactor(low, latency, bandWidth); + } else if (latency > high) { + latency *= calculateFactor(latency, high, bandWidth); + } + + return (rSocket.availability() * weightedStats.weightedAvailability()) + / (1.0d + latency * (pending + 1)); + } + + private static double calculateFactor(final double u, final double l, final double bandWidth) { + final double alpha = (u - l) / bandWidth; + return Math.pow(1 + alpha, EXP_FACTOR); + } + + /** + * Create an instance of {@link WeightedLoadbalanceStrategy} with default settings, which include + * round-robin load-balancing and 5 {@link #maxPairSelectionAttempts}. + */ + public static WeightedLoadbalanceStrategy create() { + return new Builder().build(); + } + + /** Return a builder to create a {@link WeightedLoadbalanceStrategy} with. */ + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link WeightedLoadbalanceStrategy}. */ + public static class Builder { + + private int maxPairSelectionAttempts = 5; + + @Nullable private Function weightedStatsResolver; + + private Builder() {} + + /** + * How many times to try to randomly select a pair of RSocket connections with non-zero + * availability. This is applicable when there are more than two connections in the pool. If the + * number of attempts is exceeded, the last selected pair is used. + * + *

By default this is set to 5. + * + * @param numberOfAttempts the iteration count + */ + public Builder maxPairSelectionAttempts(int numberOfAttempts) { + this.maxPairSelectionAttempts = numberOfAttempts; + return this; + } + + /** + * Configure how the created {@link WeightedLoadbalanceStrategy} should find the stats for a + * given RSocket. + * + *

By default this resolver is not set. + * + *

When {@code WeightedLoadbalanceStrategy} is used through the {@link + * LoadbalanceRSocketClient}, the resolver does not need to be set because a {@link + * WeightedStatsRequestInterceptor} is automatically installed through the {@link + * ClientLoadbalanceStrategy} callback. If this strategy is used in any other context however, a + * resolver here must be provided. + * + * @param resolver to find the stats for an RSocket with + */ + public Builder weightedStatsResolver(Function resolver) { + this.weightedStatsResolver = resolver; + return this; + } + + /** Build the {@code WeightedLoadbalanceStrategy} instance. */ + public WeightedLoadbalanceStrategy build() { + return new WeightedLoadbalanceStrategy( + this.maxPairSelectionAttempts, this.weightedStatsResolver); + } + } + + private static class DefaultWeightedStatsResolver implements Function { + + final Map statsMap = new ConcurrentHashMap<>(); + + @Override + public WeightedStats apply(RSocket rSocket) { + return statsMap.get(rSocket); + } + + void init(RSocketConnector connector) { + connector.interceptors( + registry -> + registry.forRequestsInRequester( + (Function) + rSocket -> { + final WeightedStatsRequestInterceptor interceptor = + new WeightedStatsRequestInterceptor() { + @Override + public void dispose() { + statsMap.remove(rSocket); + } + }; + statsMap.put(rSocket, interceptor); + + return interceptor; + })); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java new file mode 100644 index 000000000..5ebe668ce --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStats.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; + +/** + * Contract to expose the stats required in {@link WeightedLoadbalanceStrategy} to calculate an + * algorithmic weight for an {@code RSocket}. The weight helps to select an {@code RSocket} for + * load-balancing. + * + * @since 1.1 + */ +public interface WeightedStats { + + double higherQuantileLatency(); + + double lowerQuantileLatency(); + + int pending(); + + double predictedLatency(); + + double weightedAvailability(); + + /** + * Create a proxy for the given {@code RSocket} that attaches the stats contained in this instance + * and exposes them as {@link WeightedStats}. + * + * @param rsocket the RSocket to wrap + * @return the wrapped RSocket + * @since 1.1.1 + */ + default RSocket wrap(RSocket rsocket) { + return new WeightedStatsRSocketProxy(rsocket, this); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java new file mode 100644 index 000000000..f2cf3fbd0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRSocketProxy.java @@ -0,0 +1,62 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; + +/** + * Package private {@code RSocketProxy} used from {@link WeightedStats#wrap(RSocket)} to attach a + * {@link WeightedStats} instance to an {@code RSocket}. + */ +class WeightedStatsRSocketProxy extends RSocketProxy implements WeightedStats { + + private final WeightedStats weightedStats; + + public WeightedStatsRSocketProxy(RSocket source, WeightedStats weightedStats) { + super(source); + this.weightedStats = weightedStats; + } + + @Override + public double higherQuantileLatency() { + return this.weightedStats.higherQuantileLatency(); + } + + @Override + public double lowerQuantileLatency() { + return this.weightedStats.lowerQuantileLatency(); + } + + @Override + public int pending() { + return this.weightedStats.pending(); + } + + @Override + public double predictedLatency() { + return this.weightedStats.predictedLatency(); + } + + @Override + public double weightedAvailability() { + return this.weightedStats.weightedAvailability(); + } + + public WeightedStats getDelegate() { + return this.weightedStats; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java new file mode 100644 index 000000000..ec2c88b19 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/WeightedStatsRequestInterceptor.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import reactor.util.annotation.Nullable; + +/** + * {@link RequestInterceptor} that hooks into request lifecycle and calls methods of the parent + * class to manage tracking state and expose {@link WeightedStats}. + * + *

This interceptor the default mechanism for gathering stats when {@link + * WeightedLoadbalanceStrategy} is used with {@link LoadbalanceRSocketClient}. + * + * @since 1.1 + * @see LoadbalanceRSocketClient + * @see WeightedLoadbalanceStrategy + */ +public class WeightedStatsRequestInterceptor extends BaseWeightedStats + implements RequestInterceptor { + + final Int2LongHashMap requestsStartTime = new Int2LongHashMap(-1); + + public WeightedStatsRequestInterceptor() { + super(); + } + + @Override + public final void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + final long startTime = startRequest(); + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + requestsStartTime.put(streamId, startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + this.startStream(); + } + } + + @Override + public final void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + long endTime = stopRequest(startTime); + if (t == null) { + record(endTime - startTime); + } + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + + if (t != null) { + updateAvailability(0.0d); + } else { + updateAvailability(1.0d); + } + } + + @Override + public final void onCancel(int streamId, FrameType requestType) { + switch (requestType) { + case REQUEST_FNF: + case REQUEST_RESPONSE: + long startTime; + final Int2LongHashMap requestsStartTime = this.requestsStartTime; + synchronized (requestsStartTime) { + startTime = requestsStartTime.remove(streamId); + } + stopRequest(startTime); + break; + case REQUEST_STREAM: + case REQUEST_CHANNEL: + stopStream(); + break; + } + } + + @Override + public final void onReject(Throwable rejectionReason, FrameType requestType, ByteBuf metadata) {} + + @Override + public void dispose() {} +} diff --git a/src/main/java/io/reactivesocket/exceptions/UnsupportedSetupException.java b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java similarity index 65% rename from src/main/java/io/reactivesocket/exceptions/UnsupportedSetupException.java rename to rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java index 8647cf638..19668e99c 100644 --- a/src/main/java/io/reactivesocket/exceptions/UnsupportedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/loadbalance/package-info.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; -public class UnsupportedSetupException extends SetupException { - public UnsupportedSetupException(String message) { - super(message); - } -} +/** Support client load-balancing in RSocket Java. */ +@NonNullApi +package io.rsocket.loadbalance; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java new file mode 100644 index 000000000..c16c4dc52 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/AuthMetadataCodec.java @@ -0,0 +1,334 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.util.CharByteBufUtil; + +public class AuthMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + static final int USERNAME_BYTES_LENGTH = 2; + static final int AUTH_TYPE_ID_LENGTH = 1; + + static final char[] EMPTY_CHARS_ARRAY = new char[0]; + + private AuthMetadataCodec() {} + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customAuthType the custom mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code customAuthType} is non US_ASCII string or + * empty string or its length is greater than 128 bytes + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, String customAuthType, ByteBuf metadata) { + + int actualASCIILength = ByteBufUtil.utf8Bytes(customAuthType); + if (actualASCIILength != customAuthType.length()) { + throw new IllegalArgumentException("custom auth type must be US_ASCII characters only"); + } + if (actualASCIILength < 1 || actualASCIILength > 128) { + throw new IllegalArgumentException( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + int capacity = 1 + actualASCIILength; + ByteBuf headerBuffer = allocator.buffer(capacity, capacity); + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + headerBuffer.writeByte(actualASCIILength - 1); + + ByteBufUtil.reserveAndWriteUtf8(headerBuffer, customAuthType, actualASCIILength); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using custom authentication type + * + * @param allocator the {@link ByteBufAllocator} to create intermediate buffers as needed. + * @param authType the well-known mime type to encode. + * @param metadata the metadata value to encode. + * @throws IllegalArgumentException in case of {@code authType} is {@link + * WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} or {@link + * WellKnownAuthType#UNKNOWN_RESERVED_AUTH_TYPE} + */ + public static ByteBuf encodeMetadata( + ByteBufAllocator allocator, WellKnownAuthType authType, ByteBuf metadata) { + + if (authType == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE + || authType == WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE) { + throw new IllegalArgumentException("only allowed AuthType should be used"); + } + + int capacity = AUTH_TYPE_ID_LENGTH; + ByteBuf headerBuffer = + allocator + .buffer(capacity, capacity) + .writeByte(authType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + return allocator.compositeBuffer(2).addComponents(true, headerBuffer, metadata); + } + + /** + * Encode a Authentication CompositeMetadata payload using Simple Authentication format + * + * @throws IllegalArgumentException if the username length is greater than 65535 + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param username the char sequence which represents user name. + * @param password the char sequence which represents user password. + */ + public static ByteBuf encodeSimpleMetadata( + ByteBufAllocator allocator, char[] username, char[] password) { + + int usernameLength = CharByteBufUtil.utf8Bytes(username); + if (usernameLength > 65535) { + throw new IllegalArgumentException( + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); + } + + int passwordLength = CharByteBufUtil.utf8Bytes(password); + int capacity = AUTH_TYPE_ID_LENGTH + USERNAME_BYTES_LENGTH + usernameLength + passwordLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.SIMPLE.getIdentifier() | STREAM_METADATA_KNOWN_MASK) + .writeShort(usernameLength); + + CharByteBufUtil.writeUtf8(buffer, username); + CharByteBufUtil.writeUtf8(buffer, password); + + return buffer; + } + + /** + * Encode a Authentication CompositeMetadata payload using Bearer Authentication format + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param token the char sequence which represents BEARER token. + */ + public static ByteBuf encodeBearerMetadata(ByteBufAllocator allocator, char[] token) { + + int tokenLength = CharByteBufUtil.utf8Bytes(token); + int capacity = AUTH_TYPE_ID_LENGTH + tokenLength; + final ByteBuf buffer = + allocator + .buffer(capacity, capacity) + .writeByte(WellKnownAuthType.BEARER.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + + CharByteBufUtil.writeUtf8(buffer, token); + + return buffer; + } + + /** + * Encode a new Authentication Metadata payload information, first verifying if the passed {@link + * String} matches a {@link WellKnownAuthType} (in which case it will be encoded in a compressed + * fashion using the mime id of that type). + * + *

Prefer using {@link #encodeMetadata(ByteBufAllocator, String, ByteBuf)} if you already know + * that the mime type is not a {@link WellKnownAuthType}. + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param authType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeMetadata(ByteBufAllocator, WellKnownAuthType, ByteBuf) + * @see #encodeMetadata(ByteBufAllocator, String, ByteBuf) + */ + public static ByteBuf encodeMetadataWithCompression( + ByteBufAllocator allocator, String authType, ByteBuf metadata) { + WellKnownAuthType wkn = WellKnownAuthType.fromString(authType); + if (wkn == WellKnownAuthType.UNPARSEABLE_AUTH_TYPE) { + return AuthMetadataCodec.encodeMetadata(allocator, authType, metadata); + } else { + return AuthMetadataCodec.encodeMetadata(allocator, wkn, metadata); + } + } + + /** + * Get the first {@code byte} from a {@link ByteBuf} and check whether it is length or {@link + * WellKnownAuthType}. Assuming said buffer properly contains such a {@code byte} + * + * @param metadata byteBuf used to get information from + */ + public static boolean isWellKnownAuthType(ByteBuf metadata) { + byte lengthOrId = metadata.getByte(0); + return (lengthOrId & STREAM_METADATA_LENGTH_MASK) != lengthOrId; + } + + /** + * Read first byte from the given {@code metadata} and tries to convert it's value to {@link + * WellKnownAuthType}. + * + * @param metadata given metadata buffer to read from + * @return Return on of the know Auth types or {@link WellKnownAuthType#UNPARSEABLE_AUTH_TYPE} if + * field's value is length or unknown auth type + * @throws IllegalStateException if not enough readable bytes in the given {@link ByteBuf} + */ + public static WellKnownAuthType readWellKnownAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 1) { + throw new IllegalStateException( + "Unable to decode Well Know Auth type. Not enough readable bytes"); + } + byte lengthOrId = metadata.readByte(); + int normalizedId = (byte) (lengthOrId & STREAM_METADATA_LENGTH_MASK); + + if (normalizedId != lengthOrId) { + return WellKnownAuthType.fromIdentifier(normalizedId); + } + + return WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + } + + /** + * Read up to 129 bytes from the given metadata in order to get the custom Auth Type + * + * @param metadata + * @return + */ + public static CharSequence readCustomAuthType(ByteBuf metadata) { + if (metadata.readableBytes() < 2) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Not enough readable bytes"); + } + + byte encodedLength = metadata.readByte(); + if (encodedLength < 0) { + throw new IllegalStateException( + "Unable to decode custom Auth type. Incorrect auth type length"); + } + + // encoded length is realLength - 1 in order to avoid intersection with 0x00 authtype + int realLength = encodedLength + 1; + if (metadata.readableBytes() < realLength) { + throw new IllegalArgumentException( + "Unable to decode custom Auth type. Malformed length or auth type string"); + } + + return metadata.readCharSequence(realLength, CharsetUtil.US_ASCII); + } + + /** + * Read all remaining {@code bytes} from the given {@link ByteBuf} and return sliced + * representation of a payload + * + * @param metadata metadata to get payload from. Please note, the {@code metadata#readIndex} + * should be set to the beginning of the payload bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if no bytes readable in the + * given one + */ + public static ByteBuf readPayload(ByteBuf metadata) { + if (metadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return metadata.readSlice(metadata.readableBytes()); + } + + /** + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length position + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if username length is zero + */ + public static ByteBuf readUsername(ByteBuf simpleAuthMetadata) { + int usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read password from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return sliced {@link ByteBuf} or {@link Unpooled#EMPTY_BUFFER} if password length is zero + */ + public static ByteBuf readPassword(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return Unpooled.EMPTY_BUFFER; + } + + return simpleAuthMetadata.readSlice(simpleAuthMetadata.readableBytes()); + } + /** + * Read up to 65537 {@code bytes} from the given {@link ByteBuf} where the first two bytes + * represent username length and the subsequent number of bytes equal to read length + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the username length byte + * @return {@code char[]} which represents UTF-8 username + */ + public static char[] readUsernameAsCharArray(ByteBuf simpleAuthMetadata) { + int usernameLength = readUsernameLength(simpleAuthMetadata); + + if (usernameLength == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, usernameLength); + } + + /** + * Read all the remaining {@code byte}s from the given {@link ByteBuf} which represents user's + * password + * + * @param simpleAuthMetadata the given metadata to read username from. Please note, the {@code + * simpleAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readPasswordAsCharArray(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(simpleAuthMetadata, simpleAuthMetadata.readableBytes()); + } + + /** + * Read all the remaining {@code bytes} from the given {@link ByteBuf} + * + * @param bearerAuthMetadata the given metadata to read username from. Please note, the {@code + * bearerAuthMetadata#readIndex} should be set to the beginning of the password bytes + * @return {@code char[]} which represents UTF-8 password + */ + public static char[] readBearerTokenAsCharArray(ByteBuf bearerAuthMetadata) { + if (bearerAuthMetadata.readableBytes() == 0) { + return EMPTY_CHARS_ARRAY; + } + + return CharByteBufUtil.readUtf8(bearerAuthMetadata, bearerAuthMetadata.readableBytes()); + } + + private static int readUsernameLength(ByteBuf simpleAuthMetadata) { + if (simpleAuthMetadata.readableBytes() < 2) { + throw new IllegalStateException( + "Unable to decode custom username. Not enough readable bytes"); + } + + int usernameLength = simpleAuthMetadata.readUnsignedShort(); + + if (simpleAuthMetadata.readableBytes() < usernameLength) { + throw new IllegalArgumentException( + "Unable to decode username. Malformed username length or content"); + } + + return usernameLength; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java new file mode 100644 index 000000000..1c3ae9423 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadata.java @@ -0,0 +1,241 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataCodec.computeNextEntryIndex; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.hasEntry; +import static io.rsocket.metadata.CompositeMetadataCodec.isWellKnownMimeType; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadata.Entry; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; +import reactor.util.annotation.Nullable; + +/** + * An {@link Iterable} wrapper around a {@link ByteBuf} that exposes metadata entry information at + * each decoding step. This is only possible on frame types used to initiate interactions, if the + * SETUP metadata mime type was {@link WellKnownMimeType#MESSAGE_RSOCKET_COMPOSITE_METADATA}. + * + *

This allows efficient incremental decoding of the entries (without moving the source's {@link + * io.netty.buffer.ByteBuf#readerIndex()}). The buffer is assumed to contain just enough bytes to + * represent one or more entries (mime type compressed or not). The decoding stops when the buffer + * reaches 0 readable bytes, and fails if it contains bytes but not enough to correctly decode an + * entry. + * + *

A note on future-proofness: it is possible to come across a compressed mime type that this + * implementation doesn't recognize. This is likely to be due to the use of a byte id that is merely + * reserved in this implementation, but maps to a {@link WellKnownMimeType} in the implementation + * that encoded the metadata. This can be detected by detecting that an entry is a {@link + * ReservedMimeTypeEntry}. In this case {@link Entry#getMimeType()} will return {@code null}. The + * encoded id can be retrieved using {@link ReservedMimeTypeEntry#getType()}. The byte and content + * buffer should be kept around and re-encoded using {@link + * CompositeMetadataCodec#encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, byte, ByteBuf)} + * in case passing that entry through is required. + */ +public final class CompositeMetadata implements Iterable { + + private final boolean retainSlices; + + private final ByteBuf source; + + public CompositeMetadata(ByteBuf source, boolean retainSlices) { + this.source = source; + this.retainSlices = retainSlices; + } + + /** + * Turn this {@link CompositeMetadata} into a sequential {@link Stream}. + * + * @return the composite metadata sequential {@link Stream} + */ + public Stream stream() { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize( + iterator(), Spliterator.DISTINCT | Spliterator.NONNULL | Spliterator.ORDERED), + false); + } + + /** + * An {@link Iterator} that lazily decodes {@link Entry} in this composite metadata. + * + * @return the composite metadata {@link Iterator} + */ + @Override + public Iterator iterator() { + return new Iterator() { + + private int entryIndex = 0; + + @Override + public boolean hasNext() { + return hasEntry(CompositeMetadata.this.source, this.entryIndex); + } + + @Override + public Entry next() { + ByteBuf[] headerAndData = + decodeMimeAndContentBuffersSlices( + CompositeMetadata.this.source, + this.entryIndex, + CompositeMetadata.this.retainSlices); + + ByteBuf header = headerAndData[0]; + ByteBuf data = headerAndData[1]; + + this.entryIndex = computeNextEntryIndex(this.entryIndex, header, data); + + if (!isWellKnownMimeType(header)) { + CharSequence typeString = decodeMimeTypeFromMimeBuffer(header); + if (typeString == null) { + throw new IllegalStateException("MIME type cannot be null"); + } + + return new ExplicitMimeTimeEntry(data, typeString.toString()); + } + + byte id = decodeMimeIdFromMimeBuffer(header); + WellKnownMimeType type = WellKnownMimeType.fromIdentifier(id); + + if (WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE == type) { + return new ReservedMimeTypeEntry(data, id); + } + + return new WellKnownMimeTypeEntry(data, type); + } + }; + } + + /** An entry in the {@link CompositeMetadata}. */ + public interface Entry { + + /** + * Returns the un-decoded content of the {@link Entry}. + * + * @return the un-decoded content of the {@link Entry} + */ + ByteBuf getContent(); + + /** + * Returns the MIME type of the entry, if it can be decoded. + * + * @return the MIME type of the entry, if it can be decoded, otherwise {@code null}. + */ + @Nullable + String getMimeType(); + } + + /** An {@link Entry} backed by an explicitly declared MIME type. */ + public static final class ExplicitMimeTimeEntry implements Entry { + + private final ByteBuf content; + + private final String type; + + public ExplicitMimeTimeEntry(ByteBuf content, String type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type; + } + } + + /** + * An {@link Entry} backed by a {@link WellKnownMimeType} entry, but one that is not understood by + * this implementation. + */ + public static final class ReservedMimeTypeEntry implements Entry { + private final ByteBuf content; + private final int type; + + public ReservedMimeTypeEntry(ByteBuf content, int type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + /** + * {@inheritDoc} Since this entry represents a compressed id that couldn't be decoded, this is + * always {@code null}. + */ + @Override + public String getMimeType() { + return null; + } + + /** + * Returns the reserved, but unknown {@link WellKnownMimeType} for this entry. Range is 0-127 + * (inclusive). + * + * @return the reserved, but unknown {@link WellKnownMimeType} for this entry + */ + public int getType() { + return this.type; + } + } + + /** An {@link Entry} backed by a {@link WellKnownMimeType}. */ + public static final class WellKnownMimeTypeEntry implements Entry { + + private final ByteBuf content; + private final WellKnownMimeType type; + + public WellKnownMimeTypeEntry(ByteBuf content, WellKnownMimeType type) { + this.content = content; + this.type = type; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.type.getString(); + } + + /** + * Returns the {@link WellKnownMimeType} for this entry. + * + * @return the {@link WellKnownMimeType} for this entry + */ + public WellKnownMimeType getType() { + return this.type; + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java new file mode 100644 index 000000000..5e00abba8 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/CompositeMetadataCodec.java @@ -0,0 +1,385 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.util.NumberUtils; +import reactor.util.annotation.Nullable; + +/** + * A flyweight class that can be used to encode/decode composite metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * CompositeMetadata} for an Iterator-like approach to decoding entries. + */ +public class CompositeMetadataCodec { + + static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private CompositeMetadataCodec() {} + + public static int computeNextEntryIndex( + int currentEntryIndex, ByteBuf headerSlice, ByteBuf contentSlice) { + return currentEntryIndex + + headerSlice.readableBytes() // this includes the mime length byte + + 3 // 3 bytes of the content length, which are excluded from the slice + + contentSlice.readableBytes(); + } + + /** + * Decode the next metadata entry (a mime header + content pair of {@link ByteBuf}) from a {@link + * ByteBuf} that contains at least enough bytes for one more such entry. These buffers are + * actually slices of the full metadata buffer, and this method doesn't move the full metadata + * buffer's {@link ByteBuf#readerIndex()}. As such, it requires the user to provide an {@code + * index} to read from. The next index is computed by calling {@link #computeNextEntryIndex(int, + * ByteBuf, ByteBuf)}. Size of the first buffer (the "header buffer") drives which decoding method + * should be further applied to it. + * + *

The header buffer is either: + * + *

    + *
  • made up of a single byte: this represents an encoded mime id, which can be further + * decoded using {@link #decodeMimeIdFromMimeBuffer(ByteBuf)} + *
  • made up of 2 or more bytes: this represents an encoded mime String + its length, which + * can be further decoded using {@link #decodeMimeTypeFromMimeBuffer(ByteBuf)}. Note the + * encoded length, in the first byte, is skipped by this decoding method because the + * remaining length of the buffer is that of the mime string. + *
+ * + * @param compositeMetadata the source {@link ByteBuf} that originally contains one or more + * metadata entries + * @param entryIndex the {@link ByteBuf#readerIndex()} to start decoding from. original reader + * index is kept on the source buffer + * @param retainSlices should produced metadata entry buffers {@link ByteBuf#slice() slices} be + * {@link ByteBuf#retainedSlice() retained}? + * @return a {@link ByteBuf} array of length 2 containing the mime header buffer + * slice and the content buffer slice, or one of the + * zero-length error constant arrays + */ + public static ByteBuf[] decodeMimeAndContentBuffersSlices( + ByteBuf compositeMetadata, int entryIndex, boolean retainSlices) { + compositeMetadata.markReaderIndex(); + compositeMetadata.readerIndex(entryIndex); + + if (compositeMetadata.isReadable()) { + ByteBuf mime; + int ridx = compositeMetadata.readerIndex(); + byte mimeIdOrLength = compositeMetadata.readByte(); + if ((mimeIdOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + mime = + retainSlices + ? compositeMetadata.retainedSlice(ridx, 1) + : compositeMetadata.slice(ridx, 1); + } else { + // M flag unset, remaining 7 bits are the length of the mime + int mimeLength = Byte.toUnsignedInt(mimeIdOrLength) + 1; + + if (compositeMetadata.isReadable( + mimeLength)) { // need to be able to read an extra mimeLength bytes + // here we need a way for the returned ByteBuf to differentiate between a + // 1-byte length mime type and a 1 byte encoded mime id, preferably without + // re-applying the byte mask. The easiest way is to include the initial byte + // and have further decoding ignore the first byte. 1 byte buffer == id, 2+ byte + // buffer == full mime string. + mime = + retainSlices + ? + // we accommodate that we don't read from current readerIndex, but + // readerIndex - 1 ("0"), for a total slice size of mimeLength + 1 + compositeMetadata.retainedSlice(ridx, mimeLength + 1) + : compositeMetadata.slice(ridx, mimeLength + 1); + // we thus need to skip the bytes we just sliced, but not the flag/length byte + // which was already skipped in initial read + compositeMetadata.skipBytes(mimeLength); + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + + if (compositeMetadata.isReadable(3)) { + // ensures the length medium can be read + final int metadataLength = compositeMetadata.readUnsignedMedium(); + if (compositeMetadata.isReadable(metadataLength)) { + ByteBuf metadata = + retainSlices + ? compositeMetadata.readRetainedSlice(metadataLength) + : compositeMetadata.readSlice(metadataLength); + compositeMetadata.resetReaderIndex(); + return new ByteBuf[] {mime, metadata}; + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } else { + compositeMetadata.resetReaderIndex(); + throw new IllegalStateException("metadata is malformed"); + } + } + compositeMetadata.resetReaderIndex(); + throw new IllegalArgumentException( + String.format("entry index %d is larger than buffer size", entryIndex)); + } + + /** + * Decode a {@code byte} compressed mime id from a {@link ByteBuf}, assuming said buffer properly + * contains such an id. + * + *

The buffer must have exactly one readable byte, which is assumed to have been tested for + * mime id encoding via the {@link #STREAM_METADATA_KNOWN_MASK} mask ({@code firstByte & + * STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK}). + * + *

If there is no readable byte, the negative identifier of {@link + * WellKnownMimeType#UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeBuffer the buffer that should next contain the compressed mime id byte + * @return the compressed mime id, between 0 and 127, or a negative id if the input is invalid + * @see #decodeMimeTypeFromMimeBuffer(ByteBuf) + */ + public static byte decodeMimeIdFromMimeBuffer(ByteBuf mimeBuffer) { + if (mimeBuffer.readableBytes() != 1) { + return WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier(); + } + return (byte) (mimeBuffer.readByte() & STREAM_METADATA_LENGTH_MASK); + } + + /** + * Decode a {@link CharSequence} custome mime type from a {@link ByteBuf}, assuming said buffer + * properly contains such a mime type. + * + *

The buffer must at least have two readable bytes, which distinguishes it from the {@link + * #decodeMimeIdFromMimeBuffer(ByteBuf) compressed id} case. The first byte is a size and the + * remaining bytes must correspond to the {@link CharSequence}, encoded fully in US_ASCII. As a + * result, the first byte can simply be skipped, and the remaining of the buffer be decoded to the + * mime type. + * + *

If the mime header buffer is less than 2 bytes long, returns {@code null}. + * + * @param flyweightMimeBuffer the mime header {@link ByteBuf} that contains length + custom mime + * type + * @return the decoded custom mime type, as a {@link CharSequence}, or null if the input is + * invalid + * @see #decodeMimeIdFromMimeBuffer(ByteBuf) + */ + @Nullable + public static CharSequence decodeMimeTypeFromMimeBuffer(ByteBuf flyweightMimeBuffer) { + if (flyweightMimeBuffer.readableBytes() < 2) { + throw new IllegalStateException("unable to decode explicit MIME type"); + } + // the encoded length is assumed to be kept at the start of the buffer + // but also assumed to be irrelevant because the rest of the slice length + // actually already matches _decoded_length + flyweightMimeBuffer.skipBytes(1); + int mimeStringLength = flyweightMimeBuffer.readableBytes(); + return flyweightMimeBuffer.readCharSequence(mimeStringLength, CharsetUtil.US_ASCII); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, without checking if the {@link String} can be matched with a well known compressable + * mime type. Prefer using this method and {@link #encodeAndAddMetadata(CompositeByteBuf, + * ByteBufAllocator, WellKnownMimeType, ByteBuf)} if you know in advance whether or not the mime + * is well known. Otherwise use {@link #encodeAndAddMetadataWithCompression(CompositeByteBuf, + * ByteBufAllocator, String, ByteBuf)} + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param customMimeType the custom mime type to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String customMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, customMimeType, metadata.readableBytes()), metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + public static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + WellKnownMimeType knownMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, knownMimeType.getIdentifier(), metadata.readableBytes()), + metadata); + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}, first verifying if the passed {@link String} matches a {@link WellKnownMimeType} (in + * which case it will be encoded in a compressed fashion using the mime id of that type). + * + *

Prefer using {@link #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, String, + * ByteBuf)} if you already know that the mime type is not a {@link WellKnownMimeType}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param mimeType the mime type to encode, as a {@link String}. well known mime types are + * compressed. + * @param metadata the metadata value to encode. + * @see #encodeAndAddMetadata(CompositeByteBuf, ByteBufAllocator, WellKnownMimeType, ByteBuf) + */ + // see #encodeMetadataHeader(ByteBufAllocator, String, int) + public static void encodeAndAddMetadataWithCompression( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + String mimeType, + ByteBuf metadata) { + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + compositeMetaData.addComponents( + true, encodeMetadataHeader(allocator, mimeType, metadata.readableBytes()), metadata); + } else { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, wkn.getIdentifier(), metadata.readableBytes()), + metadata); + } + } + + /** + * Returns whether there is another entry available at a given index + * + * @param compositeMetadata the buffer to inspect + * @param entryIndex the index to check at + * @return whether there is another entry available at a given index + */ + public static boolean hasEntry(ByteBuf compositeMetadata, int entryIndex) { + return compositeMetadata.writerIndex() - entryIndex > 0; + } + + /** + * Returns whether the header represents a well-known MIME type. + * + * @param header the header to inspect + * @return whether the header represents a well-known MIME type + */ + public static boolean isWellKnownMimeType(ByteBuf header) { + return header.readableBytes() == 1; + } + + /** + * Encode a new sub-metadata information into a composite metadata {@link CompositeByteBuf + * buffer}. + * + * @param compositeMetaData the buffer that will hold all composite metadata information. + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param unknownCompressedMimeType the id of the {@link + * WellKnownMimeType#UNKNOWN_RESERVED_MIME_TYPE} to encode. + * @param metadata the metadata value to encode. + */ + // see #encodeMetadataHeader(ByteBufAllocator, byte, int) + static void encodeAndAddMetadata( + CompositeByteBuf compositeMetaData, + ByteBufAllocator allocator, + byte unknownCompressedMimeType, + ByteBuf metadata) { + compositeMetaData.addComponents( + true, + encodeMetadataHeader(allocator, unknownCompressedMimeType, metadata.readableBytes()), + metadata); + } + + /** + * Encode a custom mime type and a metadata value length into a newly allocated {@link ByteBuf}. + * + *

This larger representation encodes the mime type representation's length on a single byte, + * then the representation itself, then the unsigned metadata value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param customMime a custom mime type to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, String customMime, int metadataLength) { + ByteBuf metadataHeader = allocator.buffer(4 + customMime.length()); + // reserve 1 byte for the customMime length + // /!\ careful not to read that first byte, which is random at this point + int writerIndexInitial = metadataHeader.writerIndex(); + metadataHeader.writerIndex(writerIndexInitial + 1); + + // write the custom mime in UTF8 but validate it is all ASCII-compatible + // (which produces the right result since ASCII chars are still encoded on 1 byte in UTF8) + int customMimeLength = ByteBufUtil.writeUtf8(metadataHeader, customMime); + if (!ByteBufUtil.isText( + metadataHeader, metadataHeader.readerIndex() + 1, customMimeLength, CharsetUtil.US_ASCII)) { + metadataHeader.release(); + throw new IllegalArgumentException("custom mime type must be US_ASCII characters only"); + } + if (customMimeLength < 1 || customMimeLength > 128) { + metadataHeader.release(); + throw new IllegalArgumentException( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + metadataHeader.markWriterIndex(); + + // go back to beginning and write the length + // encoded length is one less than actual length, since 0 is never a valid length, which gives + // wider representation range + metadataHeader.writerIndex(writerIndexInitial); + metadataHeader.writeByte(customMimeLength - 1); + + // go back to post-mime type and write the metadata content length + metadataHeader.resetWriterIndex(); + NumberUtils.encodeUnsignedMedium(metadataHeader, metadataLength); + + return metadataHeader; + } + + /** + * Encode a {@link WellKnownMimeType well known mime type} and a metadata value length into a + * newly allocated {@link ByteBuf}. + * + *

This compact representation encodes the mime type via its ID on a single byte, and the + * unsigned value length on 3 additional bytes. + * + * @param allocator the {@link ByteBufAllocator} to use to create the buffer. + * @param mimeType a byte identifier of a {@link WellKnownMimeType} to encode. + * @param metadataLength the metadata length to append to the buffer as an unsigned 24 bits + * integer. + * @return the encoded mime and metadata length information + */ + static ByteBuf encodeMetadataHeader( + ByteBufAllocator allocator, byte mimeType, int metadataLength) { + ByteBuf buffer = allocator.buffer(4, 4).writeByte(mimeType | STREAM_METADATA_KNOWN_MASK); + + NumberUtils.encodeUnsignedMedium(buffer, metadataLength); + + return buffer; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java new file mode 100644 index 000000000..2e03bd754 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/MimeTypeMetadataCodec.java @@ -0,0 +1,137 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayList; +import java.util.List; + +/** + * Provides support for encoding and decoding the per-stream MIME type to use for payload data. + * + *

For more on the format of the metadata, see the + * Stream Data MIME Types extension specification. + * + * @since 1.1.1 + */ +public class MimeTypeMetadataCodec { + + private static final int STREAM_METADATA_KNOWN_MASK = 0x80; // 1000 0000 + + private static final byte STREAM_METADATA_LENGTH_MASK = 0x7F; // 0111 1111 + + private MimeTypeMetadataCodec() {} + + /** + * Encode a {@link WellKnownMimeType} into a newly allocated single byte {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType well-known MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, WellKnownMimeType mimeType) { + return allocator.buffer(1, 1).writeByte(mimeType.getIdentifier() | STREAM_METADATA_KNOWN_MASK); + } + + /** + * Encode the given MIME type into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeType MIME type to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, String mimeType) { + if (mimeType == null || mimeType.length() == 0) { + throw new IllegalArgumentException("MIME type is required"); + } + WellKnownMimeType wkn = WellKnownMimeType.fromString(mimeType); + if (wkn == WellKnownMimeType.UNPARSEABLE_MIME_TYPE) { + return encodeCustomMimeType(allocator, mimeType); + } else { + return encode(allocator, wkn); + } + } + + /** + * Encode multiple MIME types into a newly allocated {@link ByteBuf}. + * + * @param allocator the allocator to create the buffer with + * @param mimeTypes MIME types to encode + * @return the resulting buffer + */ + public static ByteBuf encode(ByteBufAllocator allocator, List mimeTypes) { + if (mimeTypes == null || mimeTypes.size() == 0) { + throw new IllegalArgumentException("No MIME types provided"); + } + CompositeByteBuf compositeByteBuf = allocator.compositeBuffer(); + for (String mimeType : mimeTypes) { + ByteBuf byteBuf = encode(allocator, mimeType); + compositeByteBuf.addComponents(true, byteBuf); + } + return compositeByteBuf; + } + + private static ByteBuf encodeCustomMimeType(ByteBufAllocator allocator, String customMimeType) { + ByteBuf byteBuf = allocator.buffer(1 + customMimeType.length()); + + byteBuf.writerIndex(1); + int length = ByteBufUtil.writeUtf8(byteBuf, customMimeType); + + if (!ByteBufUtil.isText(byteBuf, 1, length, CharsetUtil.US_ASCII)) { + byteBuf.release(); + throw new IllegalArgumentException("MIME type must be ASCII characters only"); + } + + if (length < 1 || length > 128) { + byteBuf.release(); + throw new IllegalArgumentException( + "MIME type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + byteBuf.markWriterIndex(); + byteBuf.writerIndex(0); + byteBuf.writeByte(length - 1); + byteBuf.resetWriterIndex(); + + return byteBuf; + } + + /** + * Decode the per-stream MIME type metadata encoded in the given {@link ByteBuf}. + * + * @return the decoded MIME types + */ + public static List decode(ByteBuf byteBuf) { + List mimeTypes = new ArrayList<>(); + while (byteBuf.isReadable()) { + byte idOrLength = byteBuf.readByte(); + if ((idOrLength & STREAM_METADATA_KNOWN_MASK) == STREAM_METADATA_KNOWN_MASK) { + byte id = (byte) (idOrLength & STREAM_METADATA_LENGTH_MASK); + WellKnownMimeType wellKnownMimeType = WellKnownMimeType.fromIdentifier(id); + mimeTypes.add(wellKnownMimeType.toString()); + } else { + int length = Byte.toUnsignedInt(idOrLength) + 1; + mimeTypes.add(byteBuf.readCharSequence(length, CharsetUtil.US_ASCII).toString()); + } + } + return mimeTypes; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java new file mode 100644 index 000000000..d1f2643dc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/RoutingMetadata.java @@ -0,0 +1,18 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; + +/** + * Routing Metadata extension from + * https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + * + * @author linux_china + */ +public class RoutingMetadata extends TaggingMetadata { + private static final WellKnownMimeType ROUTING_MIME_TYPE = + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING; + + public RoutingMetadata(ByteBuf content) { + super(ROUTING_MIME_TYPE.getString(), content); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java new file mode 100644 index 000000000..e22d97106 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadata.java @@ -0,0 +1,64 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import java.nio.charset.StandardCharsets; +import java.util.Iterator; +import java.util.Spliterator; +import java.util.Spliterators; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +/** + * Tagging metadata from https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + * + * @author linux_china + */ +public class TaggingMetadata implements Iterable, CompositeMetadata.Entry { + /** Tag max length in bytes */ + private static int TAG_LENGTH_MAX = 0xFF; + + private String mimeType; + private ByteBuf content; + + public TaggingMetadata(String mimeType, ByteBuf content) { + this.mimeType = mimeType; + this.content = content; + } + + public Stream stream() { + return StreamSupport.stream( + Spliterators.spliteratorUnknownSize( + iterator(), Spliterator.DISTINCT | Spliterator.NONNULL | Spliterator.ORDERED), + false); + } + + @Override + public Iterator iterator() { + return new Iterator() { + @Override + public boolean hasNext() { + return content.readerIndex() < content.capacity(); + } + + @Override + public String next() { + int tagLength = TAG_LENGTH_MAX & content.readByte(); + if (tagLength > 0) { + return content.readSlice(tagLength).toString(StandardCharsets.UTF_8); + } else { + return ""; + } + } + }; + } + + @Override + public ByteBuf getContent() { + return this.content; + } + + @Override + public String getMimeType() { + return this.mimeType; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java new file mode 100644 index 000000000..d766cf59f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TaggingMetadataCodec.java @@ -0,0 +1,76 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import java.nio.charset.StandardCharsets; +import java.util.Collection; + +/** + * A flyweight class that can be used to encode/decode tagging metadata information to/from {@link + * ByteBuf}. This is intended for low-level efficient manipulation of such buffers. See {@link + * TaggingMetadata} for an Iterator-like approach to decoding entries. + * + * @author linux_china + */ +public class TaggingMetadataCodec { + /** Tag max length in bytes */ + private static int TAG_LENGTH_MAX = 0xFF; + + /** + * create routing metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return routing metadata + */ + public static RoutingMetadata createRoutingMetadata( + ByteBufAllocator allocator, Collection tags) { + return new RoutingMetadata(createTaggingContent(allocator, tags)); + } + + /** + * create tagging metadata from composite metadata entry + * + * @param entry composite metadata entry + * @return tagging metadata + */ + public static TaggingMetadata createTaggingMetadata(CompositeMetadata.Entry entry) { + return new TaggingMetadata(entry.getMimeType(), entry.getContent()); + } + + /** + * create tagging metadata + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param knownMimeType the {@link WellKnownMimeType} to encode. + * @param tags tag values + * @return Tagging Metadata + */ + public static TaggingMetadata createTaggingMetadata( + ByteBufAllocator allocator, String knownMimeType, Collection tags) { + return new TaggingMetadata(knownMimeType, createTaggingContent(allocator, tags)); + } + + /** + * create tagging content + * + * @param allocator the {@link ByteBufAllocator} to use to create intermediate buffers as needed. + * @param tags tag values + * @return tagging content + */ + public static ByteBuf createTaggingContent(ByteBufAllocator allocator, Collection tags) { + CompositeByteBuf taggingContent = allocator.compositeBuffer(); + for (String key : tags) { + int length = ByteBufUtil.utf8Bytes(key); + if (length == 0 || length > TAG_LENGTH_MAX) { + continue; + } + ByteBuf byteBuf = allocator.buffer().writeByte(length); + byteBuf.writeCharSequence(key, StandardCharsets.UTF_8); + taggingContent.addComponent(true, byteBuf); + } + return taggingContent; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java new file mode 100644 index 000000000..d276a9436 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadata.java @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +/** + * Represents decoded tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public final class TracingMetadata { + + final long traceIdHigh; + final long traceId; + private final boolean hasParentId; + final long parentId; + final long spanId; + final boolean isEmpty; + final boolean isNotSampled; + final boolean isSampled; + final boolean isDebug; + + TracingMetadata( + long traceIdHigh, + long traceId, + long spanId, + boolean hasParentId, + long parentId, + boolean isEmpty, + boolean isNotSampled, + boolean isSampled, + boolean isDebug) { + this.traceIdHigh = traceIdHigh; + this.traceId = traceId; + this.spanId = spanId; + this.hasParentId = hasParentId; + this.parentId = parentId; + this.isEmpty = isEmpty; + this.isNotSampled = isNotSampled; + this.isSampled = isSampled; + this.isDebug = isDebug; + } + + /** When non-zero, the trace containing this span uses 128-bit trace identifiers. */ + public long traceIdHigh() { + return traceIdHigh; + } + + /** Unique 8-byte identifier for a trace, set on all spans within it. */ + public long traceId() { + return traceId; + } + + /** Indicates if the parent's {@link #spanId} or if this the root span in a trace. */ + public final boolean hasParent() { + return hasParentId; + } + + /** Returns the parent's {@link #spanId} where zero implies absent. */ + public long parentId() { + return parentId; + } + + /** + * Unique 8-byte identifier of this span within a trace. + * + *

A span is uniquely identified in storage by ({@linkplain #traceId}, {@linkplain #spanId}). + */ + public long spanId() { + return spanId; + } + + /** Indicates that trace IDs should be accepted for tracing. */ + public boolean isSampled() { + return isSampled; + } + + /** Indicates that trace IDs should be force traced. */ + public boolean isDebug() { + return isDebug; + } + + /** Includes that there is sampling information and no trace IDs. */ + public boolean isEmpty() { + return isEmpty; + } + + /** + * Indicated that sampling decision is present. If {@code false} means that decision is unknown + * and says explicitly that {@link #isDebug()} and {@link #isSampled()} also returns {@code + * false}. + */ + public boolean isDecided() { + return isNotSampled || isDebug || isSampled; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java new file mode 100644 index 000000000..eb44956f6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/TracingMetadataCodec.java @@ -0,0 +1,172 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; + +/** + * Represents codes for tracing metadata which is fully compatible with Zipkin B3 propagation + * + * @since 1.0 + */ +public class TracingMetadataCodec { + + static final int FLAG_EXTENDED_TRACE_ID_SIZE = 0b0000_1000; + static final int FLAG_INCLUDE_PARENT_ID = 0b0000_0100; + static final int FLAG_NOT_SAMPLED = 0b0001_0000; + static final int FLAG_SAMPLED = 0b0010_0000; + static final int FLAG_DEBUG = 0b0100_0000; + static final int FLAG_IDS_SET = 0b1000_0000; + + public static ByteBuf encodeEmpty(ByteBufAllocator allocator, Flags flag) { + + return encode(allocator, true, 0, 0, false, 0, 0, false, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, + long traceIdHigh, + long traceId, + long spanId, + long parentId, + Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, parentId, true, flag); + } + + public static ByteBuf encode128( + ByteBufAllocator allocator, long traceIdHigh, long traceId, long spanId, Flags flag) { + + return encode(allocator, false, traceIdHigh, traceId, true, spanId, 0, false, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, long parentId, Flags flag) { + + return encode(allocator, false, 0, traceId, false, spanId, parentId, true, flag); + } + + public static ByteBuf encode64( + ByteBufAllocator allocator, long traceId, long spanId, Flags flag) { + return encode(allocator, false, 0, traceId, false, spanId, 0, false, flag); + } + + static ByteBuf encode( + ByteBufAllocator allocator, + boolean isEmpty, + long traceIdHigh, + long traceId, + boolean extendedTraceId, + long spanId, + long parentId, + boolean includesParent, + Flags flag) { + int size = + 1 + + (isEmpty + ? 0 + : (Long.BYTES + + Long.BYTES + + (extendedTraceId ? Long.BYTES : 0) + + (includesParent ? Long.BYTES : 0))); + final ByteBuf buffer = allocator.buffer(size); + + int byteFlags = 0; + switch (flag) { + case NOT_SAMPLE: + byteFlags |= FLAG_NOT_SAMPLED; + break; + case SAMPLE: + byteFlags |= FLAG_SAMPLED; + break; + case DEBUG: + byteFlags |= FLAG_DEBUG; + break; + } + + if (isEmpty) { + return buffer.writeByte(byteFlags); + } + + byteFlags |= FLAG_IDS_SET; + + if (extendedTraceId) { + byteFlags |= FLAG_EXTENDED_TRACE_ID_SIZE; + } + + if (includesParent) { + byteFlags |= FLAG_INCLUDE_PARENT_ID; + } + + buffer.writeByte(byteFlags); + + if (extendedTraceId) { + buffer.writeLong(traceIdHigh); + } + + buffer.writeLong(traceId).writeLong(spanId); + + if (includesParent) { + buffer.writeLong(parentId); + } + + return buffer; + } + + public static TracingMetadata decode(ByteBuf byteBuf) { + byteBuf.markReaderIndex(); + try { + byte flags = byteBuf.readByte(); + boolean isNotSampled = (flags & FLAG_NOT_SAMPLED) == FLAG_NOT_SAMPLED; + boolean isSampled = (flags & FLAG_SAMPLED) == FLAG_SAMPLED; + boolean isDebug = (flags & FLAG_DEBUG) == FLAG_DEBUG; + boolean isIDSet = (flags & FLAG_IDS_SET) == FLAG_IDS_SET; + + if (!isIDSet) { + return new TracingMetadata(0, 0, 0, false, 0, true, isNotSampled, isSampled, isDebug); + } + + boolean extendedTraceId = + (flags & FLAG_EXTENDED_TRACE_ID_SIZE) == FLAG_EXTENDED_TRACE_ID_SIZE; + + long traceIdHigh; + if (extendedTraceId) { + traceIdHigh = byteBuf.readLong(); + } else { + traceIdHigh = 0; + } + + long traceId = byteBuf.readLong(); + long spanId = byteBuf.readLong(); + + boolean includesParent = (flags & FLAG_INCLUDE_PARENT_ID) == FLAG_INCLUDE_PARENT_ID; + + long parentId; + if (includesParent) { + parentId = byteBuf.readLong(); + } else { + parentId = 0; + } + + return new TracingMetadata( + traceIdHigh, + traceId, + spanId, + includesParent, + parentId, + false, + isNotSampled, + isSampled, + isDebug); + } finally { + byteBuf.resetReaderIndex(); + } + } + + public enum Flags { + UNDECIDED, + NOT_SAMPLE, + SAMPLE, + DEBUG + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java new file mode 100644 index 000000000..66c98701c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownAuthType.java @@ -0,0 +1,121 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Auth Types, as defined in the eponymous extension. Such auth types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownAuthType { + UNPARSEABLE_AUTH_TYPE("UNPARSEABLE_AUTH_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_AUTH_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + SIMPLE("simple", (byte) 0x00), + BEARER("bearer", (byte) 0x01); + // ... reserved for future use ... + + static final WellKnownAuthType[] TYPES_BY_AUTH_ID; + static final Map TYPES_BY_AUTH_STRING; + + static { + // precompute an array of all valid auth ids, filling the blanks with the RESERVED enum + TYPES_BY_AUTH_ID = new WellKnownAuthType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_AUTH_ID, UNKNOWN_RESERVED_AUTH_TYPE); + // also prepare a Map of the types by auth string + TYPES_BY_AUTH_STRING = new LinkedHashMap<>(128); + + for (WellKnownAuthType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_AUTH_ID[value.getIdentifier()] = value; + TYPES_BY_AUTH_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownAuthType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownAuthType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_AUTH_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_AUTH_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownAuthType}, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_AUTH_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownAuthType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_AUTH_TYPE; + } + return TYPES_BY_AUTH_ID[id]; + } + + /** + * Find the {@link WellKnownAuthType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownAuthType}, the {@link + * #UNPARSEABLE_AUTH_TYPE} is returned. + * + * @param authType the looked up auth type + * @return the matching {@link WellKnownAuthType}, or {@link #UNPARSEABLE_AUTH_TYPE} if none + * matches + */ + public static WellKnownAuthType fromString(String authType) { + if (authType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_AUTH_TYPE's text has been used + if (authType.equals(UNKNOWN_RESERVED_AUTH_TYPE.str)) { + return UNPARSEABLE_AUTH_TYPE; + } + + return TYPES_BY_AUTH_STRING.getOrDefault(authType, UNPARSEABLE_AUTH_TYPE); + } + + /** @return the byte identifier of the auth type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the auth type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java new file mode 100644 index 000000000..e78e87629 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/metadata/WellKnownMimeType.java @@ -0,0 +1,167 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Enumeration of Well Known Mime Types, as defined in the eponymous extension. Such mime types are + * used in composite metadata (which can include routing and/or tracing metadata). Per + * specification, identifiers are between 0 and 127 (inclusive). + */ +public enum WellKnownMimeType { + UNPARSEABLE_MIME_TYPE("UNPARSEABLE_MIME_TYPE_DO_NOT_USE", (byte) -2), + UNKNOWN_RESERVED_MIME_TYPE("UNKNOWN_YET_RESERVED_DO_NOT_USE", (byte) -1), + + APPLICATION_AVRO("application/avro", (byte) 0x00), + APPLICATION_CBOR("application/cbor", (byte) 0x01), + APPLICATION_GRAPHQL("application/graphql", (byte) 0x02), + APPLICATION_GZIP("application/gzip", (byte) 0x03), + APPLICATION_JAVASCRIPT("application/javascript", (byte) 0x04), + APPLICATION_JSON("application/json", (byte) 0x05), + APPLICATION_OCTET_STREAM("application/octet-stream", (byte) 0x06), + APPLICATION_PDF("application/pdf", (byte) 0x07), + APPLICATION_THRIFT("application/vnd.apache.thrift.binary", (byte) 0x08), + APPLICATION_PROTOBUF("application/vnd.google.protobuf", (byte) 0x09), + APPLICATION_XML("application/xml", (byte) 0x0A), + APPLICATION_ZIP("application/zip", (byte) 0x0B), + AUDIO_AAC("audio/aac", (byte) 0x0C), + AUDIO_MP3("audio/mp3", (byte) 0x0D), + AUDIO_MP4("audio/mp4", (byte) 0x0E), + AUDIO_MPEG3("audio/mpeg3", (byte) 0x0F), + AUDIO_MPEG("audio/mpeg", (byte) 0x10), + AUDIO_OGG("audio/ogg", (byte) 0x11), + AUDIO_OPUS("audio/opus", (byte) 0x12), + AUDIO_VORBIS("audio/vorbis", (byte) 0x13), + IMAGE_BMP("image/bmp", (byte) 0x14), + IMAGE_GIF("image/gif", (byte) 0x15), + IMAGE_HEIC_SEQUENCE("image/heic-sequence", (byte) 0x16), + IMAGE_HEIC("image/heic", (byte) 0x17), + IMAGE_HEIF_SEQUENCE("image/heif-sequence", (byte) 0x18), + IMAGE_HEIF("image/heif", (byte) 0x19), + IMAGE_JPEG("image/jpeg", (byte) 0x1A), + IMAGE_PNG("image/png", (byte) 0x1B), + IMAGE_TIFF("image/tiff", (byte) 0x1C), + MULTIPART_MIXED("multipart/mixed", (byte) 0x1D), + TEXT_CSS("text/css", (byte) 0x1E), + TEXT_CSV("text/csv", (byte) 0x1F), + TEXT_HTML("text/html", (byte) 0x20), + TEXT_PLAIN("text/plain", (byte) 0x21), + TEXT_XML("text/xml", (byte) 0x22), + VIDEO_H264("video/H264", (byte) 0x23), + VIDEO_H265("video/H265", (byte) 0x24), + VIDEO_VP8("video/VP8", (byte) 0x25), + APPLICATION_HESSIAN("application/x-hessian", (byte) 0x26), + APPLICATION_JAVA_OBJECT("application/x-java-object", (byte) 0x27), + APPLICATION_CLOUDEVENTS_JSON("application/cloudevents+json", (byte) 0x28), + + // ... reserved for future use ... + MESSAGE_RSOCKET_MIMETYPE("message/x.rsocket.mime-type.v0", (byte) 0x7A), + MESSAGE_RSOCKET_ACCEPT_MIMETYPES("message/x.rsocket.accept-mime-types.v0", (byte) 0x7B), + MESSAGE_RSOCKET_AUTHENTICATION("message/x.rsocket.authentication.v0", (byte) 0x7C), + MESSAGE_RSOCKET_TRACING_ZIPKIN("message/x.rsocket.tracing-zipkin.v0", (byte) 0x7D), + MESSAGE_RSOCKET_ROUTING("message/x.rsocket.routing.v0", (byte) 0x7E), + MESSAGE_RSOCKET_COMPOSITE_METADATA("message/x.rsocket.composite-metadata.v0", (byte) 0x7F); + + static final WellKnownMimeType[] TYPES_BY_MIME_ID; + static final Map TYPES_BY_MIME_STRING; + + static { + // precompute an array of all valid mime ids, filling the blanks with the RESERVED enum + TYPES_BY_MIME_ID = new WellKnownMimeType[128]; // 0-127 inclusive + Arrays.fill(TYPES_BY_MIME_ID, UNKNOWN_RESERVED_MIME_TYPE); + // also prepare a Map of the types by mime string + TYPES_BY_MIME_STRING = new HashMap<>(128); + + for (WellKnownMimeType value : values()) { + if (value.getIdentifier() >= 0) { + TYPES_BY_MIME_ID[value.getIdentifier()] = value; + TYPES_BY_MIME_STRING.put(value.getString(), value); + } + } + } + + private final byte identifier; + private final String str; + + WellKnownMimeType(String str, byte identifier) { + this.str = str; + this.identifier = identifier; + } + + /** + * Find the {@link WellKnownMimeType} for the given identifier (as an {@code int}). Valid + * identifiers are defined to be integers between 0 and 127, inclusive. Identifiers outside of + * this range will produce the {@link #UNPARSEABLE_MIME_TYPE}. Additionally, some identifiers in + * that range are still only reserved and don't have a type associated yet: this method returns + * the {@link #UNKNOWN_RESERVED_MIME_TYPE} when passing such an identifier, which lets call sites + * potentially detect this and keep the original representation when transmitting the associated + * metadata buffer. + * + * @param id the looked up identifier + * @return the {@link WellKnownMimeType}, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is out + * of the specification's range, or {@link #UNKNOWN_RESERVED_MIME_TYPE} if the id is one that + * is merely reserved but unknown to this implementation. + */ + public static WellKnownMimeType fromIdentifier(int id) { + if (id < 0x00 || id > 0x7F) { + return UNPARSEABLE_MIME_TYPE; + } + return TYPES_BY_MIME_ID[id]; + } + + /** + * Find the {@link WellKnownMimeType} for the given {@link String} representation. If the + * representation is {@code null} or doesn't match a {@link WellKnownMimeType}, the {@link + * #UNPARSEABLE_MIME_TYPE} is returned. + * + * @param mimeType the looked up mime type + * @return the matching {@link WellKnownMimeType}, or {@link #UNPARSEABLE_MIME_TYPE} if none + * matches + */ + public static WellKnownMimeType fromString(String mimeType) { + if (mimeType == null) throw new IllegalArgumentException("type must be non-null"); + + // force UNPARSEABLE if by chance UNKNOWN_RESERVED_MIME_TYPE's text has been used + if (mimeType.equals(UNKNOWN_RESERVED_MIME_TYPE.str)) { + return UNPARSEABLE_MIME_TYPE; + } + + return TYPES_BY_MIME_STRING.getOrDefault(mimeType, UNPARSEABLE_MIME_TYPE); + } + + /** @return the byte identifier of the mime type, guaranteed to be positive or zero. */ + public byte getIdentifier() { + return identifier; + } + + /** + * @return the mime type represented as a {@link String}, which is made of US_ASCII compatible + * characters only + */ + public String getString() { + return str; + } + + /** @see #getString() */ + @Override + public String toString() { + return str; + } +} diff --git a/src/main/java/io/reactivesocket/exceptions/ApplicationException.java b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java similarity index 58% rename from src/main/java/io/reactivesocket/exceptions/ApplicationException.java rename to rsocket-core/src/main/java/io/rsocket/metadata/package-info.java index d31a21a05..3fb9ae1d6 100644 --- a/src/main/java/io/reactivesocket/exceptions/ApplicationException.java +++ b/rsocket-core/src/main/java/io/rsocket/metadata/package-info.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; -public class ApplicationException extends Throwable { - public ApplicationException(String message) { - super(message); - } +/** + * Contains implementations of RSocket protocol extensions related + * to the use of metadata. + */ +@NonNullApi +package io.rsocket.metadata; - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/package-info.java b/rsocket-core/src/main/java/io/rsocket/package-info.java new file mode 100644 index 000000000..6fe74fb38 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/package-info.java @@ -0,0 +1,29 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains key contracts of the RSocket programming model including {@link io.rsocket.RSocket + * RSocket} for performing or handling RSocket interactions, {@link io.rsocket.SocketAcceptor + * SocketAcceptor} for declaring responders, {@link io.rsocket.Payload Payload} for access to the + * content of a payload, and others. + * + *

To connect to or start a server see {@link io.rsocket.core.RSocketConnector RSocketConnector} + * and {@link io.rsocket.core.RSocketServer RSocketServer} in {@link io.rsocket.core}. + */ +@NonNullApi +package io.rsocket; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java new file mode 100644 index 000000000..9a134153d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/CompositeRequestInterceptor.java @@ -0,0 +1,147 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import java.util.List; +import reactor.core.publisher.Operators; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +class CompositeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor[] requestInterceptors; + + CompositeRequestInterceptor(RequestInterceptor[] requestInterceptors) { + this.requestInterceptors = requestInterceptors; + } + + @Override + public void dispose() { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + requestInterceptor.dispose(); + } + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + final RequestInterceptor[] requestInterceptors = this.requestInterceptors; + for (int i = 0; i < requestInterceptors.length; i++) { + final RequestInterceptor requestInterceptor = requestInterceptors[i]; + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } + + @Nullable + static RequestInterceptor create(List interceptors) { + switch (interceptors.size()) { + case 0: + return null; + case 1: + return new SafeRequestInterceptor(interceptors.get(0)); + default: + return new CompositeRequestInterceptor(interceptors.toArray(new RequestInterceptor[0])); + } + } + + static class SafeRequestInterceptor implements RequestInterceptor { + + final RequestInterceptor requestInterceptor; + + public SafeRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + } + + @Override + public void dispose() { + requestInterceptor.dispose(); + } + + @Override + public boolean isDisposed() { + return requestInterceptor.isDisposed(); + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onStart(streamId, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable cause) { + try { + requestInterceptor.onTerminate(streamId, requestType, cause); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + try { + requestInterceptor.onCancel(streamId, requestType); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + try { + requestInterceptor.onReject(rejectionReason, requestType, metadata); + } catch (Throwable t) { + Operators.onErrorDropped(t, Context.empty()); + } + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java new file mode 100644 index 000000000..5d3a43b03 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/DuplexConnectionInterceptor.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.plugins; + +import io.rsocket.DuplexConnection; +import java.util.function.BiFunction; + +/** + * Contract to decorate a {@link DuplexConnection} and intercept the sending and receiving of + * RSocket frames at the transport level. + */ +public @FunctionalInterface interface DuplexConnectionInterceptor + extends BiFunction { + + enum Type { + /** @deprecated since 1.1.0-M2. Will be removed in 1.2 */ + @Deprecated + SETUP, + CLIENT, + SERVER, + SOURCE + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java new file mode 100644 index 000000000..7c9a90f54 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InitializingInterceptorRegistry.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.DuplexConnection; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import reactor.util.annotation.Nullable; + +/** + * Extends {@link InterceptorRegistry} with methods for building a chain of registered interceptors. + * This is not intended for direct use by applications. + */ +public class InitializingInterceptorRegistry extends InterceptorRegistry { + + @Nullable + public RequestInterceptor initRequesterRequestInterceptor(RSocket rSocketRequester) { + return CompositeRequestInterceptor.create( + getRequestInterceptorsForRequester() + .stream() + .map(factory -> factory.apply(rSocketRequester)) + .collect(Collectors.toList())); + } + + @Nullable + public RequestInterceptor initResponderRequestInterceptor( + RSocket rSocketResponder, RequestInterceptor... perConnectionInterceptors) { + return CompositeRequestInterceptor.create( + Stream.concat( + Stream.of(perConnectionInterceptors), + getRequestInterceptorsForResponder() + .stream() + .map(inteptorFactory -> inteptorFactory.apply(rSocketResponder))) + .collect(Collectors.toList())); + } + + public DuplexConnection initConnection( + DuplexConnectionInterceptor.Type type, DuplexConnection connection) { + for (DuplexConnectionInterceptor interceptor : getConnectionInterceptors()) { + connection = interceptor.apply(type, connection); + } + return connection; + } + + public RSocket initRequester(RSocket rsocket) { + for (RSocketInterceptor interceptor : getRequesterInterceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public RSocket initResponder(RSocket rsocket) { + for (RSocketInterceptor interceptor : getResponderInterceptors()) { + rsocket = interceptor.apply(rsocket); + } + return rsocket; + } + + public SocketAcceptor initSocketAcceptor(SocketAcceptor acceptor) { + for (SocketAcceptorInterceptor interceptor : getSocketAcceptorInterceptors()) { + acceptor = interceptor.apply(acceptor); + } + return acceptor; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java new file mode 100644 index 000000000..680fb514f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/InterceptorRegistry.java @@ -0,0 +1,160 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.RSocket; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Provides support for registering interceptors at the following levels: + * + *

    + *
  • {@link #forConnection(DuplexConnectionInterceptor)} -- transport level + *
  • {@link #forSocketAcceptor(SocketAcceptorInterceptor)} -- for accepting new connections + *
  • {@link #forRequester(RSocketInterceptor)} -- for performing of requests + *
  • {@link #forResponder(RSocketInterceptor)} -- for responding to requests + *
+ */ +public class InterceptorRegistry { + private List> requesterRequestInterceptors = + new ArrayList<>(); + private List> responderRequestInterceptors = + new ArrayList<>(); + private List requesterRSocketInterceptors = new ArrayList<>(); + private List responderRSocketInterceptors = new ArrayList<>(); + private List socketAcceptorInterceptors = new ArrayList<>(); + private List connectionInterceptors = new ArrayList<>(); + + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInRequester( + Function interceptor) { + requesterRequestInterceptors.add(interceptor); + return this; + } + + /** + * Add an {@link RequestInterceptor} that will hook into Requester RSocket requests' phases. + * + * @param interceptor a function which accepts an {@link RSocket} and returns a new {@link + * RequestInterceptor} + * @since 1.1 + */ + public InterceptorRegistry forRequestsInResponder( + Function interceptor) { + responderRequestInterceptors.add(interceptor); + return this; + } + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for performing requests. + */ + public InterceptorRegistry forRequester(RSocketInterceptor interceptor) { + requesterRSocketInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forRequester(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forRequester(Consumer> consumer) { + consumer.accept(requesterRSocketInterceptors); + return this; + } + + /** + * Add an {@link RSocketInterceptor} that will decorate the RSocket used for resonding to + * requests. + */ + public InterceptorRegistry forResponder(RSocketInterceptor interceptor) { + responderRSocketInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forResponder(RSocketInterceptor)} with access to the list of existing + * registrations. + */ + public InterceptorRegistry forResponder(Consumer> consumer) { + consumer.accept(responderRSocketInterceptors); + return this; + } + + /** + * Add a {@link SocketAcceptorInterceptor} that will intercept the accepting of new connections. + */ + public InterceptorRegistry forSocketAcceptor(SocketAcceptorInterceptor interceptor) { + socketAcceptorInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forSocketAcceptor(SocketAcceptorInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forSocketAcceptor(Consumer> consumer) { + consumer.accept(socketAcceptorInterceptors); + return this; + } + + /** Add a {@link DuplexConnectionInterceptor}. */ + public InterceptorRegistry forConnection(DuplexConnectionInterceptor interceptor) { + connectionInterceptors.add(interceptor); + return this; + } + + /** + * Variant of {@link #forConnection(DuplexConnectionInterceptor)} with access to the list of + * existing registrations. + */ + public InterceptorRegistry forConnection(Consumer> consumer) { + consumer.accept(connectionInterceptors); + return this; + } + + List> getRequestInterceptorsForRequester() { + return requesterRequestInterceptors; + } + + List> getRequestInterceptorsForResponder() { + return responderRequestInterceptors; + } + + List getRequesterInterceptors() { + return requesterRSocketInterceptors; + } + + List getResponderInterceptors() { + return responderRSocketInterceptors; + } + + List getConnectionInterceptors() { + return connectionInterceptors; + } + + List getSocketAcceptorInterceptors() { + return socketAcceptorInterceptors; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java new file mode 100644 index 000000000..d7d9742d0 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/LimitRateInterceptor.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +/** + * Interceptor that adds {@link Flux#limitRate(int, int)} to publishers of outbound streams that + * breaks down or aggregates demand values from the remote end (i.e. {@code REQUEST_N} frames) into + * batches of a uniform size. For example the remote may request {@code Long.MAXVALUE} or it may + * start requesting one at a time, in both cases with the limit set to 64, the publisher will see a + * demand of 64 to start and subsequent batches of 48, i.e. continuing to prefetch and refill an + * internal queue when it falls to 75% full. The high and low tide marks are configurable. + * + *

See static factory methods to create an instance for a requester or for a responder. + * + *

Note: keep in mind that the {@code limitRate} operator always uses requests + * the same request values, even if the remote requests less than the limit. For example given a + * limit of 64, if the remote requests 4, 64 will be prefetched of which 4 will be sent and 60 will + * be cached. + * + * @since 1.0 + */ +public class LimitRateInterceptor implements RSocketInterceptor { + + private final int highTide; + private final int lowTide; + private final boolean requesterProxy; + + private LimitRateInterceptor(int highTide, int lowTide, boolean requesterProxy) { + this.highTide = highTide; + this.lowTide = lowTide; + this.requesterProxy = requesterProxy; + } + + @Override + public RSocket apply(RSocket socket) { + return requesterProxy ? new RequesterProxy(socket) : new ResponderProxy(socket); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int prefetchRate) { + return forResponder(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that handles request-stream and/or request-channel + * interactions with more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forResponder(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, false); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions. + * + * @param prefetchRate the prefetch rate to pass to {@link Flux#limitRate(int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int prefetchRate) { + return forRequester(prefetchRate, prefetchRate); + } + + /** + * Create an interceptor for an {@code RSocket} that performs request-channel interactions with + * more control over the overall prefetch rate and replenish threshold. + * + * @param highTide the high tide value to pass to {@link Flux#limitRate(int, int)} + * @param lowTide the low tide value to pass to {@link Flux#limitRate(int, int)} + * @return the created interceptor + */ + public static LimitRateInterceptor forRequester(int highTide, int lowTide) { + return new LimitRateInterceptor(highTide, lowTide, true); + } + + /** Responder side proxy, limits response streams. */ + private class ResponderProxy extends RSocketProxy { + + ResponderProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestStream(Payload payload) { + return super.requestStream(payload).limitRate(highTide, lowTide); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(payloads).limitRate(highTide, lowTide); + } + } + + /** Requester side proxy, limits channel request stream. */ + private class RequesterProxy extends RSocketProxy { + + RequesterProxy(RSocket source) { + super(source); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return super.requestChannel(Flux.from(payloads).limitRate(highTide, lowTide)); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java new file mode 100644 index 000000000..0cd4bb8f6 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RSocketInterceptor.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.plugins; + +import io.rsocket.RSocket; +import java.util.function.Function; + +/** + * Contract to decorate an {@link RSocket}, providing a way to intercept interactions. This can be + * applied to a {@link InterceptorRegistry#forRequester(RSocketInterceptor) requester} or {@link + * InterceptorRegistry#forResponder(RSocketInterceptor) responder} {@code RSocket} of a client or + * server. + */ +public @FunctionalInterface interface RSocketInterceptor extends Function {} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java new file mode 100644 index 000000000..08131b39d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/RequestInterceptor.java @@ -0,0 +1,79 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import reactor.core.Disposable; +import reactor.util.annotation.Nullable; +import reactor.util.context.Context; + +/** + * Class used to track the RSocket requests lifecycles. The main difference and advantage of this + * interceptor compares to {@link RSocketInterceptor} is that it allows intercepting the initial and + * terminal phases on every individual request. + * + *

Note, if any of the invocations will rise a runtime exception, this exception will be + * caught and be propagated to {@link reactor.core.publisher.Operators#onErrorDropped(Throwable, + * Context)} + * + * @since 1.1 + */ +public interface RequestInterceptor extends Disposable { + + /** + * Method which is being invoked on successful acceptance and start of a request. + * + * @param streamId used for the request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata); + + /** + * Method which is being invoked once a successfully accepted request is terminated. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onCancel(int, FrameType)}. + * + * @param streamId used by this request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param t with which this finished has terminated. Must be one of the following signals + */ + void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t); + + /** + * Method which is being invoked once a successfully accepted request is cancelled. This method + * can be invoked only after the {@link #onStart(int, FrameType, ByteBuf)} method. This method is + * exclusive with {@link #onTerminate(int, FrameType, Throwable)}. + * + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param streamId used by this request + */ + void onCancel(int streamId, FrameType requestType); + + /** + * Method which is being invoked on the request rejection. This method is being called only if the + * actual request can not be started and is called instead of the {@link #onStart(int, FrameType, + * ByteBuf)} method. The reason for rejection can be one of the following: + * + *

+ * + *

    + *
  • No available {@link io.rsocket.lease.Lease} on the requester or the responder sides + *
  • Invalid {@link io.rsocket.Payload} size or format on the Requester side, so the request + * is being rejected before the actual streamId is generated + *
  • A second subscription on the ongoing Request + *
+ * + * @param rejectionReason exception which causes rejection of a particular request + * @param requestType of the request. Must be one of the following types {@link + * FrameType#REQUEST_FNF}, {@link FrameType#REQUEST_RESPONSE}, {@link + * FrameType#REQUEST_STREAM} or {@link FrameType#REQUEST_CHANNEL} + * @param metadata taken from the initial frame + */ + void onReject(Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata); +} diff --git a/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java new file mode 100644 index 000000000..6dd850ba9 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/plugins/SocketAcceptorInterceptor.java @@ -0,0 +1,29 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.plugins; + +import io.rsocket.SocketAcceptor; +import java.util.function.Function; + +/** + * Contract to decorate a {@link SocketAcceptor}, providing access to connection {@code setup} + * information and the ability to also decorate the sockets for requesting and responding. + * + *

This could be used as an alternative to registering an individual "requester" {@code + * RSocketInterceptor} and "responder" {@code RSocketInterceptor}. + */ +public @FunctionalInterface interface SocketAcceptorInterceptor + extends Function {} diff --git a/src/main/java/io/reactivesocket/exceptions/RejectedSetupException.java b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java similarity index 64% rename from src/main/java/io/reactivesocket/exceptions/RejectedSetupException.java rename to rsocket-core/src/main/java/io/rsocket/plugins/package-info.java index de7899dce..fd9e1f01a 100644 --- a/src/main/java/io/reactivesocket/exceptions/RejectedSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/plugins/package-info.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; -public class RejectedSetupException extends SetupException implements Retryable { - public RejectedSetupException(String message) { - super(message); - } -} +/** Contracts for interception of transports, connections, and requests in in RSocket Java. */ +@NonNullApi +package io.rsocket.plugins; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java new file mode 100644 index 000000000..ca4f5dcb4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientRSocketSession.java @@ -0,0 +1,383 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.Function; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.function.Tuple2; +import reactor.util.retry.Retry; + +public class ClientRSocketSession + implements RSocketSession, + ResumeStateHolder, + CoreSubscriber> { + + private static final Logger logger = LoggerFactory.getLogger(ClientRSocketSession.class); + + final ResumableDuplexConnection resumableConnection; + final Mono> connectionFactory; + final ResumableFramesStore resumableFramesStore; + + final ByteBufAllocator allocator; + final Duration resumeSessionDuration; + final Retry retry; + final boolean cleanupStoreOnKeepAlive; + final ByteBuf resumeToken; + final String session; + final Disposable reconnectDisposable; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ClientRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; + + public ClientRSocketSession( + ByteBuf resumeToken, + ResumableDuplexConnection resumableDuplexConnection, + Mono connectionFactory, + Function>> connectionTransformer, + ResumableFramesStore resumableFramesStore, + Duration resumeSessionDuration, + Retry retry, + boolean cleanupStoreOnKeepAlive) { + this.resumeToken = resumeToken; + this.session = resumeToken.toString(CharsetUtil.UTF_8); + this.connectionFactory = + connectionFactory + .doOnDiscard( + DuplexConnection.class, + c -> { + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + c.sendErrorAndClose(connectionErrorException); + c.receive().subscribe(); + }) + .flatMap( + dc -> { + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + final long position = resumableFramesStore.framePosition(); + dc.sendFrame( + 0, + ResumeFrameCodec.encode( + dc.alloc(), + resumeToken.retain(), + // server uses this to release its cache + impliedPosition, // observed on the client side + // server uses this to check whether there is no mismatch + position // sent from the client sent + )); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeFrame[impliedPosition[{}], position[{}]] has been sent.", + session, + impliedPosition, + position); + } + + return connectionTransformer.apply(dc); + }) + .doOnDiscard(Tuple2.class, this::tryReestablishSession); + this.resumableFramesStore = resumableFramesStore; + this.allocator = resumableDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.retry = retry; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + + this.reconnectDisposable = + resumableDuplexConnection.onActiveConnectionClosed().subscribe(this::reconnect); + } + + void reconnect(int index) { + if (this.s == Operators.cancelledSubscription()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting rejected since session is closed", + session, + index); + } + return; + } + + keepAliveSupport.stop(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Connection[{}] is lost. Reconnecting to resume...", + session, + index); + } + connectionFactory + .doOnNext(this::tryReestablishSession) + .retryWhen(retry) + .timeout(resumeSessionDuration) + .subscribe(this); + } + + @Override + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); + } + + @Override + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } + } + + @Override + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Disposing", session); + } + + boolean result = Operators.terminate(S, this); + + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Sessions[isDisposed={}]", session, result); + } + + reconnectDisposable.dispose(); + resumableConnection.dispose(); + // frame store is disposed by resumable connection + // resumableFramesStore.dispose(); + + if (resumeToken.refCnt() > 0) { + resumeToken.release(); + } + } + + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); + } + + void tryReestablishSession(Tuple2 tuple2) { + if (logger.isDebugEnabled()) { + logger.debug("Active subscription is canceled {}", s == Operators.cancelledSubscription()); + } + ByteBuf shouldBeResumeOKFrame = tuple2.getT1(); + DuplexConnection nextDuplexConnection = tuple2.getT2(); + + final int streamId = FrameHeaderCodec.streamId(shouldBeResumeOKFrame); + if (streamId != 0) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + throw connectionErrorException; // throw to retry connection again + } + + final FrameType frameType = FrameHeaderCodec.nativeFrameType(shouldBeResumeOKFrame); + if (frameType == FrameType.RESUME_OK) { + // how many frames the server has received from the client + // so the client can release cached frames by this point + long remoteImpliedPos = ResumeOkFrameCodec.lastReceivedClientPos(shouldBeResumeOKFrame); + // what was the last notification from the server about number of frames being + // observed + final long position = resumableFramesStore.framePosition(); + final long impliedPosition = resumableFramesStore.frameImpliedPosition(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. ResumeOK FRAME received. ServerResumeState[remoteImpliedPosition[{}]]. ClientResumeState[impliedPosition[{}], position[{}]]", + session, + remoteImpliedPos, + impliedPosition, + position); + } + if (position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + } catch (IllegalStateException e) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + e); + } + final ConnectionErrorException t = new ConnectionErrorException(e.getMessage(), e); + + resumableConnection.dispose(nextDuplexConnection, t); + + nextDuplexConnection.sendErrorAndClose(t); + nextDuplexConnection.receive().subscribe(); + + return; + } + + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + return; + } + + keepAliveSupport.start(); + + if (logger.isDebugEnabled()) { + logger.debug("Side[client]|Session[{}]. Session has been resumed successfully", session); + } + + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[Session Expired]"); + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + // no need to do anything since connection resumable connection is liklly to + // be disposed + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("resumption_server_pos=[" + remoteImpliedPos + "]"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + } + } else if (frameType == FrameType.ERROR) { + final RuntimeException exception = Exceptions.from(0, shouldBeResumeOKFrame); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Received error frame. Terminating received connection", + session, + exception); + } + if (exception instanceof RejectedResumeException) { + resumableConnection.dispose(nextDuplexConnection, exception); + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + return; + } + + nextDuplexConnection.dispose(); + nextDuplexConnection.receive().subscribe(); + throw exception; // assume retryable exception + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[client]|Session[{}]. Illegal first frame received. RESUME_OK frame must be received before any others. Terminating received connection", + session); + } + final ConnectionErrorException connectionErrorException = + new ConnectionErrorException("RESUME_OK frame must be received before any others"); + + resumableConnection.dispose(nextDuplexConnection, connectionErrorException); + + nextDuplexConnection.sendErrorAndClose(connectionErrorException); + nextDuplexConnection.receive().subscribe(); + + // no need to do anything since remote server rejected our connection completely + } + } + + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(Tuple2 objects) {} + + @Override + public void onError(Throwable t) { + if (!Operators.terminate(S, this)) { + Operators.onErrorDropped(t, currentContext()); + } + + resumableConnection.dispose(); + } + + @Override + public void onComplete() {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java b/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java new file mode 100644 index 000000000..415a77f92 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ClientResume.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import java.time.Duration; + +public class ClientResume { + private final Duration sessionDuration; + private final ByteBuf resumeToken; + + public ClientResume(Duration sessionDuration, ByteBuf resumeToken) { + this.sessionDuration = sessionDuration; + this.resumeToken = resumeToken; + } + + public Duration sessionDuration() { + return sessionDuration; + } + + public ByteBuf resumeToken() { + return resumeToken; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java new file mode 100644 index 000000000..e23bc154b --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/InMemoryResumableFramesStore.java @@ -0,0 +1,854 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import static io.rsocket.resume.ResumableDuplexConnection.isResumableFrame; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** + * writes - n (where n is frequent, primary operation) reads - m (where m == KeepAliveFrequency) + * skip - k -> 0 (where k is the rare operation which happens after disconnection + */ +public class InMemoryResumableFramesStore extends Flux + implements ResumableFramesStore, Subscription { + + private FramesSubscriber framesSubscriber; + private static final Logger logger = LoggerFactory.getLogger(InMemoryResumableFramesStore.class); + + final Sinks.Empty disposed = Sinks.empty(); + final Queue cachedFrames; + final String side; + final String session; + final int cacheLimit; + + volatile long impliedPosition; + static final AtomicLongFieldUpdater IMPLIED_POSITION = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "impliedPosition"); + + volatile long firstAvailableFramePosition; + static final AtomicLongFieldUpdater FIRST_AVAILABLE_FRAME_POSITION = + AtomicLongFieldUpdater.newUpdater( + InMemoryResumableFramesStore.class, "firstAvailableFramePosition"); + + long remoteImpliedPosition; + + int cacheSize; + + Throwable terminal; + + CoreSubscriber actual; + CoreSubscriber pendingActual; + + volatile long state; + static final AtomicLongFieldUpdater STATE = + AtomicLongFieldUpdater.newUpdater(InMemoryResumableFramesStore.class, "state"); + + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is finalized and all related + * stores are cleaned + */ + static final long FINALIZED_FLAG = + 0b1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * InMemoryResumableFramesStore#dispose()} method + */ + static final long DISPOSED_FLAG = + 0b0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} is terminated via the {@link + * FramesSubscriber#onComplete()} or {@link FramesSubscriber#onError(Throwable)} ()} methods + */ + static final long TERMINATED_FLAG = + 0b0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** Flag which indicates that {@link InMemoryResumableFramesStore} has active frames consumer */ + static final long CONNECTED_FLAG = + 0b0001_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore} has no active frames consumer + * but there is a one pending + */ + static final long PENDING_CONNECTION_FLAG = + 0b0000_1000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some received implied position changes from the remote + * party + */ + static final long REMOTE_IMPLIED_POSITION_CHANGED_FLAG = + 0b0000_0100_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that there are some frames stored in the {@link + * io.rsocket.internal.UnboundedProcessor} which has to be cached and sent to the remote party + */ + static final long HAS_FRAME_FLAG = + 0b0000_0010_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000_0000L; + /** + * Flag which indicates that {@link InMemoryResumableFramesStore#drain(long)} has an actor which + * is currently progressing on the work. This flag should work as a guard to enter|exist into|from + * the {@link InMemoryResumableFramesStore#drain(long)} method. + */ + static final long MAX_WORK_IN_PROGRESS = + 0b0000_0000_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111_1111L; + + public InMemoryResumableFramesStore(String side, ByteBuf session, int cacheSizeBytes) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); + this.cacheLimit = cacheSizeBytes; + this.cachedFrames = new ArrayDeque<>(); + } + + public Mono saveFrames(Flux frames) { + return frames + .transform( + Operators.lift( + (__, actual) -> this.framesSubscriber = new FramesSubscriber(actual, this))) + .then(); + } + + @Override + public void releaseFrames(long remoteImpliedPos) { + long lastReceivedRemoteImpliedPosition = this.remoteImpliedPosition; + if (lastReceivedRemoteImpliedPosition > remoteImpliedPos) { + throw new IllegalStateException( + "Given Remote Implied Position is behind the last received Remote Implied Position"); + } + + this.remoteImpliedPosition = remoteImpliedPos; + + final long previousState = markRemoteImpliedPositionChanged(this); + if (isFinalized(previousState) || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | REMOTE_IMPLIED_POSITION_CHANGED_FLAG); + } + + void drain(long expectedState) { + final Fuseable.QueueSubscription qs = this.framesSubscriber.qs; + final Queue cachedFrames = this.cachedFrames; + + for (; ; ) { + if (hasRemoteImpliedPositionChanged(expectedState)) { + expectedState = handlePendingRemoteImpliedPositionChanges(expectedState, cachedFrames); + } + + if (hasPendingConnection(expectedState)) { + expectedState = handlePendingConnection(expectedState, cachedFrames); + } + + if (isConnected(expectedState)) { + if (isTerminated(expectedState)) { + handleTerminated(qs, this.terminal); + } else if (isDisposed()) { + handleDisposed(); + } else if (hasFrames(expectedState)) { + handlePendingFrames(qs); + } + } + + if (isDisposed(expectedState) || isTerminated(expectedState)) { + clearAndFinalize(this); + return; + } + + expectedState = markWorkDone(this, expectedState); + if (isFinalized(expectedState)) { + return; + } + + if (!isWorkInProgress(expectedState)) { + return; + } + } + } + + long handlePendingRemoteImpliedPositionChanges(long expectedState, Queue cachedFrames) { + final long remoteImpliedPosition = this.remoteImpliedPosition; + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toDropFromCache = Math.max(0, remoteImpliedPosition - firstAvailableFramePosition); + + if (toDropFromCache > 0) { + final int droppedFromCache = dropFramesFromCache(toDropFromCache, cachedFrames); + + if (toDropFromCache > droppedFromCache) { + this.terminal = + new IllegalStateException( + String.format( + "Local and remote state disagreement: " + + "need to remove additional %d bytes, but cache is empty", + toDropFromCache)); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + if (toDropFromCache < droppedFromCache) { + this.terminal = + new IllegalStateException( + "Local and remote state disagreement: local and remote frame sizes are not equal"); + expectedState = markTerminated(this) | TERMINATED_FLAG; + } + + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + droppedFromCache); + if (this.cacheLimit != Integer.MAX_VALUE) { + this.cacheSize -= droppedFromCache; + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Removed frames from cache to position[{}]. CacheSize[{}]", + this.side, + this.session, + this.remoteImpliedPosition, + this.cacheSize); + } + } + } + + return expectedState; + } + + void handlePendingFrames(Fuseable.QueueSubscription qs) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; + + if (empty) { + break; + } + + handleFrame(frame); + + if (!isConnected(this.state)) { + break; + } + } + } + + long handlePendingConnection(long expectedState, Queue cachedFrames) { + CoreSubscriber lastActual = null; + for (; ; ) { + final CoreSubscriber nextActual = this.pendingActual; + + if (nextActual != lastActual) { + for (final ByteBuf frame : cachedFrames) { + nextActual.onNext(frame.retainedSlice()); + } + } + + expectedState = markConnected(this, expectedState); + if (isConnected(expectedState)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Connected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + impliedPosition); + } + + this.actual = nextActual; + break; + } + + if (!hasPendingConnection(expectedState)) { + break; + } + + lastActual = nextActual; + } + return expectedState; + } + + static int dropFramesFromCache(long toRemoveBytes, Queue cache) { + int removedBytes = 0; + while (toRemoveBytes > removedBytes && cache.size() > 0) { + final ByteBuf cachedFrame = cache.poll(); + final int frameSize = cachedFrame.readableBytes(); + + cachedFrame.release(); + + removedBytes += frameSize; + } + + return removedBytes; + } + + @Override + public Flux resumeStream() { + return this; + } + + @Override + public long framePosition() { + return this.firstAvailableFramePosition; + } + + @Override + public long frameImpliedPosition() { + return this.impliedPosition & Long.MAX_VALUE; + } + + @Override + public boolean resumableFrameReceived(ByteBuf frame) { + final int frameSize = frame.readableBytes(); + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (impliedPosition < 0) { + return false; + } + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition + frameSize)) { + return true; + } + } + } + + void pauseImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, impliedPosition | Long.MIN_VALUE)) { + logger.debug( + "Side[{}]|Session[{}]. Paused at position[{}]", side, session, impliedPosition); + return; + } + } + } + + void resumeImplied() { + for (; ; ) { + final long impliedPosition = this.impliedPosition; + + final long restoredImpliedPosition = impliedPosition & Long.MAX_VALUE; + if (IMPLIED_POSITION.compareAndSet(this, impliedPosition, restoredImpliedPosition)) { + logger.debug( + "Side[{}]|Session[{}]. Resumed at position[{}]", + side, + session, + restoredImpliedPosition); + return; + } + } + } + + @Override + public Mono onClose() { + return disposed.asMono(); + } + + @Override + public void dispose() { + final long previousState = markDisposed(this); + if (isFinalized(previousState) + || isDisposed(previousState) + || isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | DISPOSED_FLAG); + } + + void clearCache() { + final Queue frames = this.cachedFrames; + this.cacheSize = 0; + + ByteBuf frame; + while ((frame = frames.poll()) != null) { + frame.release(); + } + } + + @Override + public boolean isDisposed() { + return isDisposed(this.state); + } + + void handleFrame(ByteBuf frame) { + final boolean isResumable = isResumableFrame(frame); + if (isResumable) { + handleResumableFrame(frame); + return; + } + + handleConnectionFrame(frame); + } + + void handleTerminated(Fuseable.QueueSubscription qs, @Nullable Throwable t) { + for (; ; ) { + final ByteBuf frame = qs.poll(); + final boolean empty = frame == null; + + if (empty) { + break; + } + + handleFrame(frame); + } + if (t != null) { + this.actual.onError(t); + } else { + this.actual.onComplete(); + } + } + + void handleDisposed() { + this.actual.onError(new CancellationException("Disposed")); + } + + void handleConnectionFrame(ByteBuf frame) { + this.actual.onNext(frame); + } + + void handleResumableFrame(ByteBuf frame) { + final Queue frames = this.cachedFrames; + final int incomingFrameSize = frame.readableBytes(); + final int cacheLimit = this.cacheLimit; + + final boolean canBeStore; + int cacheSize = this.cacheSize; + if (cacheLimit != Integer.MAX_VALUE) { + final long availableSize = cacheLimit - cacheSize; + + if (availableSize < incomingFrameSize) { + final long firstAvailableFramePosition = this.firstAvailableFramePosition; + final long toRemoveBytes = incomingFrameSize - availableSize; + final int removedBytes = dropFramesFromCache(toRemoveBytes, frames); + + cacheSize = cacheSize - removedBytes; + canBeStore = removedBytes >= toRemoveBytes; + + if (canBeStore) { + FIRST_AVAILABLE_FRAME_POSITION.lazySet(this, firstAvailableFramePosition + removedBytes); + } else { + this.cacheSize = cacheSize; + FIRST_AVAILABLE_FRAME_POSITION.lazySet( + this, firstAvailableFramePosition + removedBytes + incomingFrameSize); + } + } else { + canBeStore = true; + } + } else { + canBeStore = true; + } + + if (canBeStore) { + frames.offer(frame); + + if (cacheLimit != Integer.MAX_VALUE) { + this.cacheSize = cacheSize + incomingFrameSize; + } + } + + this.actual.onNext(canBeStore ? frame.retainedSlice() : frame); + } + + @Override + public void request(long n) {} + + @Override + public void cancel() { + pauseImplied(); + markDisconnected(this); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]. Disconnected at Position[{}] and ImpliedPosition[{}]", + side, + session, + firstAvailableFramePosition, + frameImpliedPosition()); + } + } + + @Override + public void subscribe(CoreSubscriber actual) { + resumeImplied(); + actual.onSubscribe(this); + this.pendingActual = actual; + + final long previousState = markPendingConnection(this); + if (isDisposed(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isTerminated(previousState)) { + actual.onError(new CancellationException("Disposed")); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + drain((previousState + 1) | PENDING_CONNECTION_FLAG); + } + + static class FramesSubscriber + implements CoreSubscriber, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + final InMemoryResumableFramesStore parent; + + Fuseable.QueueSubscription qs; + + boolean done; + + FramesSubscriber(CoreSubscriber actual, InMemoryResumableFramesStore parent) { + this.actual = actual; + this.parent = parent; + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + if (Operators.validate(this.qs, s)) { + final Fuseable.QueueSubscription qs = (Fuseable.QueueSubscription) s; + this.qs = qs; + + final int m = qs.requestFusion(Fuseable.ANY); + + if (m != Fuseable.ASYNC) { + s.cancel(); + this.actual.onSubscribe(this); + this.actual.onError(new IllegalStateException("Source has to be ASYNC fuseable")); + return; + } + + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf byteBuf) { + final InMemoryResumableFramesStore parent = this.parent; + long previousState = InMemoryResumableFramesStore.markFrameAdded(parent); + + if (isFinalized(previousState)) { + this.qs.clear(); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + if (isConnected(previousState) || hasPendingConnection(previousState)) { + parent.drain((previousState + 1) | HAS_FRAME_FLAG); + } + } + + @Override + public void onError(Throwable t) { + if (this.done) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + parent.terminal = t; + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + Operators.onErrorDropped(t, this.actual.currentContext()); + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void onComplete() { + if (this.done) { + return; + } + + final InMemoryResumableFramesStore parent = this.parent; + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain((previousState + 1) | TERMINATED_FLAG); + } + + @Override + public void cancel() { + if (this.done) { + return; + } + + this.done = true; + + final long previousState = InMemoryResumableFramesStore.markTerminated(parent); + if (isFinalized(previousState)) { + return; + } + + if (isWorkInProgress(previousState)) { + return; + } + + parent.drain(previousState | TERMINATED_FLAG); + } + + @Override + public void request(long n) {} + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public Void poll() { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return false; + } + + @Override + public void clear() {} + } + + static long markFrameAdded(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + long nextState = state; + if (isConnected(state) || hasPendingConnection(state) || isWorkInProgress(state)) { + nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? nextState : nextState + 1; + } + + if (STATE.compareAndSet(store, state, nextState | HAS_FRAME_FLAG)) { + return state; + } + } + } + + static long markPendingConnection(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state) || isDisposed(state) || isTerminated(state)) { + return state; + } + + if (isConnected(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : state + 1; + if (STATE.compareAndSet(store, state, nextState | PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markRemoteImpliedPositionChanged(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | REMOTE_IMPLIED_POSITION_CHANGED_FLAG)) { + return state; + } + } + } + + static long markDisconnected(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + if (STATE.compareAndSet(store, state, state & ~CONNECTED_FLAG & ~PENDING_CONNECTION_FLAG)) { + return state; + } + } + } + + static long markWorkDone(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (expectedState != state) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state & ~MAX_WORK_IN_PROGRESS & ~REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markConnected(InMemoryResumableFramesStore store, long expectedState) { + for (; ; ) { + final long state = store.state; + + if (state != expectedState) { + return state; + } + + if (isFinalized(state)) { + return state; + } + + final long nextState = state ^ PENDING_CONNECTION_FLAG | CONNECTED_FLAG; + if (STATE.compareAndSet(store, state, nextState)) { + return nextState; + } + } + } + + static long markTerminated(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | TERMINATED_FLAG)) { + return state; + } + } + } + + static long markDisposed(InMemoryResumableFramesStore store) { + for (; ; ) { + final long state = store.state; + + if (isFinalized(state)) { + return state; + } + + final long nextState = + (state & MAX_WORK_IN_PROGRESS) == MAX_WORK_IN_PROGRESS ? state : (state + 1); + if (STATE.compareAndSet(store, state, nextState | DISPOSED_FLAG)) { + return state; + } + } + } + + static void clearAndFinalize(InMemoryResumableFramesStore store) { + final Fuseable.QueueSubscription qs = store.framesSubscriber.qs; + for (; ; ) { + final long state = store.state; + + qs.clear(); + store.clearCache(); + + if (isFinalized(state)) { + return; + } + + if (STATE.compareAndSet(store, state, state | FINALIZED_FLAG & ~MAX_WORK_IN_PROGRESS)) { + store.disposed.tryEmitEmpty(); + store.framesSubscriber.onComplete(); + return; + } + } + } + + static boolean isConnected(long state) { + return (state & CONNECTED_FLAG) == CONNECTED_FLAG; + } + + static boolean hasRemoteImpliedPositionChanged(long state) { + return (state & REMOTE_IMPLIED_POSITION_CHANGED_FLAG) == REMOTE_IMPLIED_POSITION_CHANGED_FLAG; + } + + static boolean hasPendingConnection(long state) { + return (state & PENDING_CONNECTION_FLAG) == PENDING_CONNECTION_FLAG; + } + + static boolean hasFrames(long state) { + return (state & HAS_FRAME_FLAG) == HAS_FRAME_FLAG; + } + + static boolean isTerminated(long state) { + return (state & TERMINATED_FLAG) == TERMINATED_FLAG; + } + + static boolean isDisposed(long state) { + return (state & DISPOSED_FLAG) == DISPOSED_FLAG; + } + + static boolean isFinalized(long state) { + return (state & FINALIZED_FLAG) == FINALIZED_FLAG; + } + + static boolean isWorkInProgress(long state) { + return (state & MAX_WORK_IN_PROGRESS) > 0; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java new file mode 100644 index 000000000..6dd3d5f4d --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/RSocketSession.java @@ -0,0 +1,25 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.keepalive.KeepAliveSupport; +import reactor.core.Disposable; + +public interface RSocketSession extends Disposable { + + void setKeepAliveSupport(KeepAliveSupport keepAliveSupport); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java new file mode 100644 index 000000000..c8811b9b3 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableDuplexConnection.java @@ -0,0 +1,447 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class ResumableDuplexConnection extends Flux + implements DuplexConnection, Subscription { + + static final Logger logger = LoggerFactory.getLogger(ResumableDuplexConnection.class); + + final String side; + final String session; + final ResumableFramesStore resumableFramesStore; + + final UnboundedProcessor savableFramesSender; + final Sinks.Empty onQueueClose; + final Sinks.Empty onLastConnectionClose; + final SocketAddress remoteAddress; + final Sinks.Many onConnectionClosedSink; + + CoreSubscriber receiveSubscriber; + FrameReceivingSubscriber activeReceivingSubscriber; + + volatile int state; + static final AtomicIntegerFieldUpdater STATE = + AtomicIntegerFieldUpdater.newUpdater(ResumableDuplexConnection.class, "state"); + + volatile DuplexConnection activeConnection; + static final AtomicReferenceFieldUpdater + ACTIVE_CONNECTION = + AtomicReferenceFieldUpdater.newUpdater( + ResumableDuplexConnection.class, DuplexConnection.class, "activeConnection"); + + int connectionIndex = 0; + + public ResumableDuplexConnection( + String side, + ByteBuf session, + DuplexConnection initialConnection, + ResumableFramesStore resumableFramesStore) { + this.side = side; + this.session = session.toString(CharsetUtil.UTF_8); + this.onConnectionClosedSink = Sinks.unsafe().many().unicast().onBackpressureBuffer(); + this.resumableFramesStore = resumableFramesStore; + this.onQueueClose = Sinks.unsafe().empty(); + this.onLastConnectionClose = Sinks.unsafe().empty(); + this.savableFramesSender = new UnboundedProcessor(onQueueClose::tryEmitEmpty); + this.remoteAddress = initialConnection.remoteAddress(); + + resumableFramesStore.saveFrames(savableFramesSender).subscribe(); + + ACTIVE_CONNECTION.lazySet(this, initialConnection); + } + + public boolean connect(DuplexConnection nextConnection) { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE + && ACTIVE_CONNECTION.compareAndSet(this, activeConnection, nextConnection)) { + + if (!activeConnection.isDisposed()) { + activeConnection.sendErrorAndClose( + new ConnectionErrorException("Connection unexpectedly replaced")); + } + + initConnection(nextConnection); + + return true; + } else { + return false; + } + } + + void initConnection(DuplexConnection nextConnection) { + final int nextConnectionIndex = this.connectionIndex + 1; + final FrameReceivingSubscriber frameReceivingSubscriber = + new FrameReceivingSubscriber(side, resumableFramesStore, receiveSubscriber); + + this.connectionIndex = nextConnectionIndex; + this.activeReceivingSubscriber = frameReceivingSubscriber; + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Connecting", side, session, connectionIndex); + } + + final Disposable resumeStreamSubscription = + resumableFramesStore + .resumeStream() + .subscribe( + f -> nextConnection.sendFrame(FrameHeaderCodec.streamId(f), f), + t -> { + dispose(nextConnection, t); + nextConnection.sendErrorAndClose(new ConnectionErrorException(t.getMessage(), t)); + }, + () -> { + final ConnectionErrorException e = + new ConnectionErrorException("Connection Closed Unexpectedly"); + dispose(nextConnection, e); + nextConnection.sendErrorAndClose(e); + }); + nextConnection.receive().subscribe(frameReceivingSubscriber); + nextConnection + .onClose() + .doFinally( + __ -> { + frameReceivingSubscriber.dispose(); + resumeStreamSubscription.dispose(); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Disconnected", + side, + session, + connectionIndex); + } + Sinks.EmitResult result = onConnectionClosedSink.tryEmitNext(nextConnectionIndex); + if (!result.equals(Sinks.EmitResult.OK)) { + logger.error( + "Side[{}]|Session[{}]|DuplexConnection[{}]. Failed to notify session of closed connection: {}", + side, + session, + connectionIndex, + result); + } + }) + .subscribe(); + } + + public void disconnect() { + final DuplexConnection activeConnection = this.activeConnection; + if (activeConnection != DisposedConnection.INSTANCE && !activeConnection.isDisposed()) { + activeConnection.dispose(); + } + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + savableFramesSender.tryEmitPrioritized(frame); + } else { + savableFramesSender.tryEmitNormal(frame); + } + } + + /** + * Publisher for a sequence of integers starting at 1, with each next number emitted when the + * currently active connection is closed and should be resumed. The Publisher never emits an error + * and completes when the connection is disposed and not resumed. + */ + Flux onActiveConnectionClosed() { + return onConnectionClosedSink.asFlux(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException rSocketErrorException) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + + savableFramesSender.tryEmitFinal( + ErrorFrameCodec.encode(activeConnection.alloc(), 0, rSocketErrorException)); + + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + + final Throwable cause = rSocketErrorException.getCause(); + if (cause == null) { + onLastConnectionClose.tryEmitEmpty(); + } else { + onLastConnectionClose.tryEmitError(cause); + } + }); + } + + @Override + public Flux receive() { + return this; + } + + @Override + public ByteBufAllocator alloc() { + return activeConnection.alloc(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError( + onQueueClose.asMono(), resumableFramesStore.onClose(), onLastConnectionClose.asMono()); + } + + @Override + public void dispose() { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + savableFramesSender.onComplete(); + activeConnection + .onClose() + .subscribe( + null, + t -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }, + () -> { + onConnectionClosedSink.tryEmitComplete(); + onLastConnectionClose.tryEmitEmpty(); + }); + } + + void dispose(DuplexConnection nextConnection, @Nullable Throwable e) { + final DuplexConnection activeConnection = + ACTIVE_CONNECTION.getAndSet(this, DisposedConnection.INSTANCE); + if (activeConnection == DisposedConnection.INSTANCE) { + return; + } + savableFramesSender.onComplete(); + nextConnection + .onClose() + .subscribe( + null, + t -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }, + () -> { + if (e != null) { + onLastConnectionClose.tryEmitError(e); + } else { + onLastConnectionClose.tryEmitEmpty(); + } + onConnectionClosedSink.tryEmitComplete(); + }); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onQueueClose.scan(Scannable.Attr.TERMINATED) + || onQueueClose.scan(Scannable.Attr.CANCELLED); + } + + @Override + public SocketAddress remoteAddress() { + return remoteAddress; + } + + @Override + public void request(long n) { + if (state == 1 && STATE.compareAndSet(this, 1, 2)) { + // happens for the very first time with the initial connection + initConnection(this.activeConnection); + } + } + + @Override + public void cancel() { + dispose(); + } + + @Override + public void subscribe(CoreSubscriber receiverSubscriber) { + if (state == 0 && STATE.compareAndSet(this, 0, 1)) { + receiveSubscriber = receiverSubscriber; + receiverSubscriber.onSubscribe(this); + } + } + + static boolean isResumableFrame(ByteBuf frame) { + return FrameHeaderCodec.streamId(frame) != 0; + } + + @Override + public String toString() { + return "ResumableDuplexConnection{" + + "side='" + + side + + '\'' + + ", session='" + + session + + '\'' + + ", remoteAddress=" + + remoteAddress + + ", state=" + + state + + ", activeConnection=" + + activeConnection + + ", connectionIndex=" + + connectionIndex + + '}'; + } + + private static final class DisposedConnection implements DuplexConnection { + + static final DisposedConnection INSTANCE = new DisposedConnection(); + + private DisposedConnection() {} + + @Override + public void dispose() {} + + @Override + public Mono onClose() { + return Mono.never(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) {} + + @Override + public Flux receive() { + return Flux.never(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) {} + + @Override + public ByteBufAllocator alloc() { + return ByteBufAllocator.DEFAULT; + } + + @Override + @SuppressWarnings("ConstantConditions") + public SocketAddress remoteAddress() { + return null; + } + } + + private static final class FrameReceivingSubscriber + implements CoreSubscriber, Disposable { + + final ResumableFramesStore resumableFramesStore; + final CoreSubscriber actual; + final String tag; + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater( + FrameReceivingSubscriber.class, Subscription.class, "s"); + + boolean cancelled; + + private FrameReceivingSubscriber( + String tag, ResumableFramesStore store, CoreSubscriber actual) { + this.tag = tag; + this.resumableFramesStore = store; + this.actual = actual; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(ByteBuf frame) { + if (cancelled || s == Operators.cancelledSubscription()) { + return; + } + + if (isResumableFrame(frame)) { + if (resumableFramesStore.resumableFrameReceived(frame)) { + actual.onNext(frame); + } + return; + } + + actual.onNext(frame); + } + + @Override + public void onError(Throwable t) { + Operators.set(S, this, Operators.cancelledSubscription()); + } + + @Override + public void onComplete() { + Operators.set(S, this, Operators.cancelledSubscription()); + } + + @Override + public void dispose() { + cancelled = true; + Operators.terminate(S, this); + } + + @Override + public boolean isDisposed() { + return cancelled || s == Operators.cancelledSubscription(); + } + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java new file mode 100644 index 000000000..80d9a36dd --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumableFramesStore.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.rsocket.Closeable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** Store for resumable frames */ +public interface ResumableFramesStore extends Closeable { + + /** + * Save resumable frames for potential resumption + * + * @param frames {@link Flux} of resumable frames + * @return {@link Mono} which completes once all resume frames are written + */ + Mono saveFrames(Flux frames); + + /** Release frames from tail of the store up to remote implied position */ + void releaseFrames(long remoteImpliedPos); + + /** + * @return {@link Flux} of frames from store tail to head. It should terminate with error if + * frames are not continuous + */ + Flux resumeStream(); + + /** @return Local frame position as defined by RSocket protocol */ + long framePosition(); + + /** @return Implied frame position as defined by RSocket protocol */ + long frameImpliedPosition(); + + /** + * Received resumable frame as defined by RSocket protocol. Implementation must increment frame + * implied position + * + * @return {@code true} if information about the frame has been stored + */ + boolean resumableFrameReceived(ByteBuf frame); +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java new file mode 100644 index 000000000..1fae24b07 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +class ResumeStateException extends RuntimeException { + private static final long serialVersionUID = -5393753463377588732L; + private final long localPos; + private final long localImpliedPos; + private final long remotePos; + private final long remoteImpliedPos; + + public ResumeStateException( + long localPos, long localImpliedPos, long remotePos, long remoteImpliedPos) { + this.localPos = localPos; + this.localImpliedPos = localImpliedPos; + this.remotePos = remotePos; + this.remoteImpliedPos = remoteImpliedPos; + } + + public long getLocalPos() { + return localPos; + } + + public long getLocalImpliedPos() { + return localImpliedPos; + } + + public long getRemotePos() { + return remotePos; + } + + public long getRemoteImpliedPos() { + return remoteImpliedPos; + } +} diff --git a/src/main/java/io/reactivesocket/exceptions/InvalidSetupException.java b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java similarity index 66% rename from src/main/java/io/reactivesocket/exceptions/InvalidSetupException.java rename to rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java index febe536bf..31687a24b 100644 --- a/src/main/java/io/reactivesocket/exceptions/InvalidSetupException.java +++ b/rsocket-core/src/main/java/io/rsocket/resume/ResumeStateHolder.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,10 +13,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; -public class InvalidSetupException extends SetupException { - public InvalidSetupException(String message) { - super(message); - } +package io.rsocket.resume; + +public interface ResumeStateHolder { + + long impliedPosition(); + + void onImpliedPosition(long remoteImpliedPos); } diff --git a/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java new file mode 100644 index 000000000..ad1b38375 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/ServerRSocketSession.java @@ -0,0 +1,301 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import java.time.Duration; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.util.concurrent.Queues; + +public class ServerRSocketSession + implements RSocketSession, ResumeStateHolder, CoreSubscriber { + private static final Logger logger = LoggerFactory.getLogger(ServerRSocketSession.class); + + final ResumableDuplexConnection resumableConnection; + final Duration resumeSessionDuration; + final ResumableFramesStore resumableFramesStore; + final String session; + final ByteBufAllocator allocator; + final boolean cleanupStoreOnKeepAlive; + + /** + * All incoming connections with the Resume intent are enqueued in this queue. Such an approach + * ensure that the new connection will affect the resumption state anyhow until the previous + * (active) connection is finally closed + */ + final Queue connectionsQueue; + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(ServerRSocketSession.class, "wip"); + + volatile Subscription s; + static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(ServerRSocketSession.class, Subscription.class, "s"); + + KeepAliveSupport keepAliveSupport; + + public ServerRSocketSession( + ByteBuf session, + ResumableDuplexConnection resumableDuplexConnection, + DuplexConnection initialDuplexConnection, + ResumableFramesStore resumableFramesStore, + Duration resumeSessionDuration, + boolean cleanupStoreOnKeepAlive) { + this.session = session.toString(CharsetUtil.UTF_8); + this.allocator = initialDuplexConnection.alloc(); + this.resumeSessionDuration = resumeSessionDuration; + this.resumableFramesStore = resumableFramesStore; + this.cleanupStoreOnKeepAlive = cleanupStoreOnKeepAlive; + this.resumableConnection = resumableDuplexConnection; + this.connectionsQueue = Queues.unboundedMultiproducer().get(); + + WIP.lazySet(this, 1); + + resumableDuplexConnection.onClose().doFinally(__ -> dispose()).subscribe(); + resumableDuplexConnection.onActiveConnectionClosed().subscribe(__ -> tryTimeoutSession()); + } + + void tryTimeoutSession() { + keepAliveSupport.stop(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Connection is lost. Trying to timeout the active session", + session); + } + + Mono.delay(resumeSessionDuration).subscribe(this); + + if (WIP.decrementAndGet(this) == 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } + } + + public void resumeWith(ByteBuf resumeFrame, DuplexConnection nextDuplexConnection) { + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. New DuplexConnection received.", session); + } + + long remotePos = ResumeFrameCodec.firstAvailableClientPos(resumeFrame); + long remoteImpliedPos = ResumeFrameCodec.lastReceivedServerPos(resumeFrame); + + connectionsQueue.offer(() -> doResume(remotePos, remoteImpliedPos, nextDuplexConnection)); + + if (WIP.getAndIncrement(this) != 0) { + return; + } + + final Runnable doResumeRunnable = connectionsQueue.poll(); + if (doResumeRunnable != null) { + doResumeRunnable.run(); + } + } + + void doResume(long remotePos, long remoteImpliedPos, DuplexConnection nextDuplexConnection) { + if (!tryCancelSessionTimeout()) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + return; + } + + long impliedPosition = resumableFramesStore.frameImpliedPosition(); + long position = resumableFramesStore.framePosition(); + + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Resume FRAME received. ServerResumeState[impliedPosition[{}], position[{}]]. ClientResumeState[remoteImpliedPosition[{}], remotePosition[{}]]", + session, + impliedPosition, + position, + remoteImpliedPos, + remotePos); + } + + if (remotePos <= impliedPosition && position <= remoteImpliedPos) { + try { + if (position != remoteImpliedPos) { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } + nextDuplexConnection.sendFrame(0, ResumeOkFrameCodec.encode(allocator, impliedPosition)); + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. ResumeOKFrame[impliedPosition[{}]] has been sent", + session, + impliedPosition); + } + } catch (Throwable t) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Exception occurred while releasing frames in the frameStore", + session, + t); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException(t.getMessage(), t); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + return; + } + + keepAliveSupport.start(); + + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Session has been resumed successfully", session); + } + + if (!resumableConnection.connect(nextDuplexConnection)) { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Session has already been expired. Terminating received connection", + session); + } + final RejectedResumeException rejectedResumeException = + new RejectedResumeException("resume_internal_error: Session Expired"); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + + // resumableConnection is likely to be disposed at this stage. Thus we have + // nothing to do + } + } else { + if (logger.isDebugEnabled()) { + logger.debug( + "Side[server]|Session[{}]. Mismatching remote and local state. Expected RemoteImpliedPosition[{}] to be greater or equal to the LocalPosition[{}] and RemotePosition[{}] to be less or equal to LocalImpliedPosition[{}]. Terminating received connection", + session, + remoteImpliedPos, + position, + remotePos, + impliedPosition); + } + + dispose(); + + final RejectedResumeException rejectedResumeException = + new RejectedResumeException( + String.format( + "resumption_pos=[ remote: { pos: %d, impliedPos: %d }, local: { pos: %d, impliedPos: %d }]", + remotePos, remoteImpliedPos, position, impliedPosition)); + nextDuplexConnection.sendErrorAndClose(rejectedResumeException); + nextDuplexConnection.receive().subscribe(); + } + } + + boolean tryCancelSessionTimeout() { + for (; ; ) { + final Subscription subscription = this.s; + + if (subscription == Operators.cancelledSubscription()) { + return false; + } + + if (S.compareAndSet(this, subscription, null)) { + subscription.cancel(); + return true; + } + } + } + + @Override + public long impliedPosition() { + return resumableFramesStore.frameImpliedPosition(); + } + + @Override + public void onImpliedPosition(long remoteImpliedPos) { + if (cleanupStoreOnKeepAlive) { + try { + resumableFramesStore.releaseFrames(remoteImpliedPos); + } catch (Throwable e) { + resumableConnection.sendErrorAndClose(new ConnectionErrorException(e.getMessage(), e)); + } + } + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.setOnce(S, this, s)) { + s.request(Long.MAX_VALUE); + } + } + + @Override + public void onNext(Long aLong) { + if (!Operators.terminate(S, this)) { + return; + } + + resumableConnection.dispose(); + } + + @Override + public void onComplete() {} + + @Override + public void onError(Throwable t) {} + + public void setKeepAliveSupport(KeepAliveSupport keepAliveSupport) { + this.keepAliveSupport = keepAliveSupport; + } + + @Override + public void dispose() { + if (logger.isDebugEnabled()) { + logger.debug("Side[server]|Session[{}]. Disposing session", session); + } + Operators.terminate(S, this); + resumableConnection.dispose(); + } + + @Override + public boolean isDisposed() { + return resumableConnection.isDisposed(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java new file mode 100644 index 000000000..736d7c77c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/SessionManager.java @@ -0,0 +1,70 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.util.annotation.Nullable; + +public class SessionManager { + static final Logger logger = LoggerFactory.getLogger(SessionManager.class); + + private volatile boolean isDisposed; + private final Map sessions = new ConcurrentHashMap<>(); + + public ServerRSocketSession save(ServerRSocketSession session, ByteBuf resumeToken) { + if (isDisposed) { + session.dispose(); + } else { + final String token = resumeToken.toString(CharsetUtil.UTF_8); + session + .resumableConnection + .onClose() + .doFinally( + __ -> { + logger.debug( + "ResumableConnection has been closed. Removing associated session {" + + token + + "}"); + if (isDisposed || sessions.get(token) == session) { + sessions.remove(token); + } + }) + .subscribe(); + ServerRSocketSession prevSession = sessions.remove(token); + if (prevSession != null) { + prevSession.dispose(); + } + sessions.put(token, session); + } + return session; + } + + @Nullable + public ServerRSocketSession get(ByteBuf resumeToken) { + return sessions.get(resumeToken.toString(CharsetUtil.UTF_8)); + } + + public void dispose() { + isDisposed = true; + sessions.values().forEach(ServerRSocketSession::dispose); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/resume/package-info.java b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java new file mode 100644 index 000000000..98744386a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/resume/package-info.java @@ -0,0 +1,27 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Contains support classes for the RSocket resume capability. + * + * @see Resuming + * Operation + */ +@NonNullApi +package io.rsocket.resume; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java new file mode 100644 index 000000000..3b8f624aa --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/transport/ClientTransport.java @@ -0,0 +1,31 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport; + +import io.rsocket.DuplexConnection; +import reactor.core.publisher.Mono; + +/** A client contract for writing transports of RSocket. */ +public interface ClientTransport extends Transport { + + /** + * Return a {@code Mono} that connects for each subscriber. + * + * @since 1.0.1 + */ + Mono connect(); +} diff --git a/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java new file mode 100644 index 000000000..92a9502a4 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/transport/ServerTransport.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport; + +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; + +/** A server contract for writing transports of RSocket. */ +public interface ServerTransport extends Transport { + + /** + * Start this server. + * + * @param acceptor to process a newly accepted connections with + * @return A handle for information about and control over the server. + * @since 1.0.1 + */ + Mono start(ConnectionAcceptor acceptor); + + /** A contract to accept a new {@code DuplexConnection}. */ + interface ConnectionAcceptor extends Function> { + + /** + * Accept a new {@code DuplexConnection} and returns {@code Publisher} signifying the end of + * processing of the connection. + * + * @param duplexConnection New {@code DuplexConnection} to be processed. + * @return A {@code Publisher} which terminates when the processing of the connection finishes. + */ + @Override + Mono apply(DuplexConnection duplexConnection); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/transport/Transport.java b/rsocket-core/src/main/java/io/rsocket/transport/Transport.java new file mode 100644 index 000000000..39386337c --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/transport/Transport.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; + +/** */ +public interface Transport { + + /** + * Configurations that exposes the maximum frame size that a {@link DuplexConnection} can bring up + * to RSocket level. + * + *

This number should not exist the 16,777,215 (maximum frame size specified by RSocket spec) + * + * @return return maximum configured frame size limit + */ + default int maxFrameLength() { + return FRAME_LENGTH_MASK; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/transport/package-info.java b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java new file mode 100644 index 000000000..00536122a --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/transport/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Client and server transport contracts for pluggable transports. */ +@NonNullApi +package io.rsocket.transport; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java new file mode 100644 index 000000000..12e0b60dc --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/ByteBufPayload.java @@ -0,0 +1,219 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.Recycler; +import io.netty.util.Recycler.Handle; +import io.rsocket.Payload; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import reactor.util.annotation.Nullable; + +public final class ByteBufPayload extends AbstractReferenceCounted implements Payload { + private static final Recycler RECYCLER = + new Recycler() { + protected ByteBufPayload newObject(Handle handle) { + return new ByteBufPayload(handle); + } + }; + + private final Handle handle; + private ByteBuf data; + private ByteBuf metadata; + + private ByteBufPayload(final Handle handle) { + this.handle = handle; + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(String data) { + return create(ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, data), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new ByteBufPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(String data, @Nullable String metadata) { + return create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, data), + metadata == null ? null : ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, metadata)); + } + + public static Payload create(CharSequence data, Charset dataCharset) { + return create( + ByteBufUtil.encodeString(ByteBufAllocator.DEFAULT, CharBuffer.wrap(data), dataCharset), + null); + } + + public static Payload create( + CharSequence data, + Charset dataCharset, + @Nullable CharSequence metadata, + Charset metadataCharset) { + return create( + ByteBufUtil.encodeString(ByteBufAllocator.DEFAULT, CharBuffer.wrap(data), dataCharset), + metadata == null + ? null + : ByteBufUtil.encodeString( + ByteBufAllocator.DEFAULT, CharBuffer.wrap(metadata), metadataCharset)); + } + + public static Payload create(byte[] data) { + return create(Unpooled.wrappedBuffer(data), null); + } + + public static Payload create(byte[] data, @Nullable byte[] metadata) { + return create( + Unpooled.wrappedBuffer(data), metadata == null ? null : Unpooled.wrappedBuffer(metadata)); + } + + public static Payload create(ByteBuffer data) { + return create(Unpooled.wrappedBuffer(data), null); + } + + public static Payload create(ByteBuffer data, @Nullable ByteBuffer metadata) { + return create( + Unpooled.wrappedBuffer(data), metadata == null ? null : Unpooled.wrappedBuffer(metadata)); + } + + public static Payload create(ByteBuf data) { + return create(data, null); + } + + public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { + ByteBufPayload payload = RECYCLER.get(); + payload.data = data; + payload.metadata = metadata; + // ensure data and metadata is set before refCnt change + payload.setRefCnt(1); + return payload; + } + + public static Payload create(Payload payload) { + return create( + payload.sliceData().retain(), + payload.hasMetadata() ? payload.sliceMetadata().retain() : null); + } + + @Override + public boolean hasMetadata() { + ensureAccessible(); + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata.slice(); + } + + @Override + public ByteBuf data() { + ensureAccessible(); + return data; + } + + @Override + public ByteBuf metadata() { + ensureAccessible(); + return metadata == null ? Unpooled.EMPTY_BUFFER : metadata; + } + + @Override + public ByteBuf sliceData() { + ensureAccessible(); + return data.slice(); + } + + @Override + public ByteBufPayload retain() { + super.retain(); + return this; + } + + @Override + public ByteBufPayload retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public ByteBufPayload touch() { + ensureAccessible(); + data.touch(); + if (metadata != null) { + metadata.touch(); + } + return this; + } + + @Override + public ByteBufPayload touch(Object hint) { + ensureAccessible(); + data.touch(hint); + if (metadata != null) { + metadata.touch(hint); + } + return this; + } + + @Override + protected void deallocate() { + data.release(); + data = null; + if (metadata != null) { + metadata.release(); + metadata = null; + } + handle.recycle(this); + } + + /** + * Should be called by every method that tries to access the buffers content to check if the + * buffer was released before. + */ + void ensureAccessible() { + if (!isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + /** + * Used internally by {@link ByteBufPayload#ensureAccessible()} to try to guard against using the + * buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java new file mode 100644 index 000000000..328fb8435 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/CharByteBufUtil.java @@ -0,0 +1,210 @@ +package io.rsocket.util; + +import static io.netty.util.internal.StringUtil.isSurrogate; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.MathUtil; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.util.Arrays; + +public class CharByteBufUtil { + + private static final byte WRITE_UTF_UNKNOWN = (byte) '?'; + + private CharByteBufUtil() {} + + /** + * Returns the exact bytes length of UTF8 character sequence. + * + *

This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[])}. + */ + public static int utf8Bytes(final char[] seq) { + return utf8ByteCount(seq, 0, seq.length); + } + + /** + * This method is producing the exact length according to {@link #writeUtf8(ByteBuf, char[], int, + * int)}. + */ + public static int utf8Bytes(final char[] seq, int start, int end) { + return utf8ByteCount(checkCharSequenceBounds(seq, start, end), start, end); + } + + private static int utf8ByteCount(final char[] seq, int start, int end) { + int i = start; + // ASCII fast path + while (i < end && seq[i] < 0x80) { + ++i; + } + // !ASCII is packed in a separate method to let the ASCII case be smaller + return i < end ? (i - start) + utf8BytesNonAscii(seq, i, end) : i - start; + } + + private static int utf8BytesNonAscii(final char[] seq, final int start, final int end) { + int encodedLength = 0; + for (int i = start; i < end; i++) { + final char c = seq[i]; + // making it 100% branchless isn't rewarding due to the many bit operations necessary! + if (c < 0x800) { + // branchless version of: (c <= 127 ? 0:1) + 1 + encodedLength += ((0x7f - c) >>> 31) + 1; + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + encodedLength++; + // WRITE_UTF_UNKNOWN + continue; + } + final char c2; + try { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. + c2 = seq[++i]; + } catch (IndexOutOfBoundsException ignored) { + encodedLength++; + // WRITE_UTF_UNKNOWN + break; + } + if (!Character.isLowSurrogate(c2)) { + // WRITE_UTF_UNKNOWN + (Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2) + encodedLength += 2; + continue; + } + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + encodedLength += 4; + } else { + encodedLength += 3; + } + } + return encodedLength; + } + + private static char[] checkCharSequenceBounds(char[] seq, int start, int end) { + if (MathUtil.isOutOfBounds(start, end - start, seq.length)) { + throw new IndexOutOfBoundsException( + "expected: 0 <= start(" + + start + + ") <= end (" + + end + + ") <= seq.length(" + + seq.length + + ')'); + } + return seq; + } + + /** + * Encode a {@code char[]} in UTF-8 and write it + * into {@link ByteBuf}. + * + *

This method returns the actual number of bytes written. + */ + public static int writeUtf8(ByteBuf buf, char[] seq) { + return writeUtf8(buf, seq, 0, seq.length); + } + + /** + * Equivalent to {@link #writeUtf8(ByteBuf, char[]) writeUtf8(buf, seq.subSequence(start, end), + * reserveBytes)} but avoids subsequence object allocation if possible. + * + * @return actual number of bytes written + */ + public static int writeUtf8(ByteBuf buf, char[] seq, int start, int end) { + return writeUtf8(buf, buf.writerIndex(), checkCharSequenceBounds(seq, start, end), start, end); + } + + // Fast-Path implementation + static int writeUtf8(ByteBuf buffer, int writerIndex, char[] seq, int start, int end) { + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = start; i < end; i++) { + char c = seq[i]; + if (c < 0x80) { + buffer.setByte(writerIndex++, (byte) c); + } else if (c < 0x800) { + buffer.setByte(writerIndex++, (byte) (0xc0 | (c >> 6))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + continue; + } + final char c2; + if (seq.length > ++i) { + // Surrogate Pair consumes 2 characters. Optimistically try to get the next character to + // avoid + // duplicate bounds checking with charAt. If an IndexOutOfBoundsException is thrown we + // will + // re-throw a more informative exception describing the problem. + c2 = seq[i]; + } else { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + break; + } + // Extra method to allow inlining the rest of writeUtf8 which is the most likely code path. + writerIndex = writeUtf8Surrogate(buffer, writerIndex, c, c2); + } else { + buffer.setByte(writerIndex++, (byte) (0xe0 | (c >> 12))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } + } + buffer.writerIndex(writerIndex); + return writerIndex - oldWriterIndex; + } + + private static int writeUtf8Surrogate(ByteBuf buffer, int writerIndex, char c, char c2) { + if (!Character.isLowSurrogate(c2)) { + buffer.setByte(writerIndex++, WRITE_UTF_UNKNOWN); + buffer.setByte(writerIndex++, Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2); + return writerIndex; + } + int codePoint = Character.toCodePoint(c, c2); + // See http://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + buffer.setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer.setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + return writerIndex; + } + + public static char[] readUtf8(ByteBuf byteBuf, int length) { + CharsetDecoder charsetDecoder = CharsetUtil.UTF_8.newDecoder(); + int en = (int) (length * (double) charsetDecoder.maxCharsPerByte()); + char[] ca = new char[en]; + + CharBuffer charBuffer = CharBuffer.wrap(ca); + ByteBuffer byteBuffer = + byteBuf.nioBufferCount() == 1 + ? byteBuf.internalNioBuffer(byteBuf.readerIndex(), length) + : byteBuf.nioBuffer(byteBuf.readerIndex(), length); + byteBuffer.mark(); + try { + CoderResult cr = charsetDecoder.decode(byteBuffer, charBuffer, true); + if (!cr.isUnderflow()) cr.throwException(); + cr = charsetDecoder.flush(charBuffer); + if (!cr.isUnderflow()) cr.throwException(); + + byteBuffer.reset(); + byteBuf.skipBytes(length); + + return safeTrim(charBuffer.array(), charBuffer.position()); + } catch (CharacterCodingException x) { + // Substitution is always enabled, + // so this shouldn't happen + throw new IllegalStateException("unable to decode char array from the given buffer", x); + } + } + + private static char[] safeTrim(char[] ca, int len) { + if (len == ca.length) return ca; + else return Arrays.copyOf(ca, len); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/Clock.java b/rsocket-core/src/main/java/io/rsocket/util/Clock.java new file mode 100644 index 000000000..4a34c988f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/Clock.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import java.util.concurrent.TimeUnit; + +/** Abstraction to get current time and durations. */ +public final class Clock { + + private Clock() { + // No Instances. + } + + public static long now() { + return System.nanoTime() / 1000; + } + + public static long elapsedSince(long timestamp) { + long t = now(); + return Math.max(0L, t - timestamp); + } + + public static TimeUnit unit() { + return TimeUnit.MICROSECONDS; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java new file mode 100644 index 000000000..08b8b2fb7 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/DefaultPayload.java @@ -0,0 +1,194 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import reactor.util.annotation.Nullable; + +/** + * An implementation of {@link Payload}. This implementation is not thread-safe, and hence + * any method can not be invoked concurrently. + */ +public final class DefaultPayload implements Payload { + public static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocateDirect(0); + + private final ByteBuffer data; + private final ByteBuffer metadata; + + private DefaultPayload(ByteBuffer data, @Nullable ByteBuffer metadata) { + this.data = data; + this.metadata = metadata; + } + + /** + * Static factory method for a text payload. Mainly looks better than "new DefaultPayload(data)" + * + * @param data the data of the payload. + * @return a payload. + */ + public static Payload create(CharSequence data) { + return create(StandardCharsets.UTF_8.encode(CharBuffer.wrap(data)), null); + } + + /** + * Static factory method for a text payload. Mainly looks better than "new DefaultPayload(data, + * metadata)" + * + * @param data the data of the payload. + * @param metadata the metadata for the payload. + * @return a payload. + */ + public static Payload create(CharSequence data, @Nullable CharSequence metadata) { + return create( + StandardCharsets.UTF_8.encode(CharBuffer.wrap(data)), + metadata == null ? null : StandardCharsets.UTF_8.encode(CharBuffer.wrap(metadata))); + } + + public static Payload create(CharSequence data, Charset dataCharset) { + return create(dataCharset.encode(CharBuffer.wrap(data)), null); + } + + public static Payload create( + CharSequence data, + Charset dataCharset, + @Nullable CharSequence metadata, + Charset metadataCharset) { + return create( + dataCharset.encode(CharBuffer.wrap(data)), + metadata == null ? null : metadataCharset.encode(CharBuffer.wrap(metadata))); + } + + public static Payload create(byte[] data) { + return create(ByteBuffer.wrap(data), null); + } + + public static Payload create(byte[] data, @Nullable byte[] metadata) { + return create(ByteBuffer.wrap(data), metadata == null ? null : ByteBuffer.wrap(metadata)); + } + + public static Payload create(ByteBuffer data) { + return create(data, null); + } + + public static Payload create(ByteBuffer data, @Nullable ByteBuffer metadata) { + return new DefaultPayload(data, metadata); + } + + public static Payload create(ByteBuf data) { + return create(data, null); + } + + public static Payload create(ByteBuf data, @Nullable ByteBuf metadata) { + try { + return create(toBytes(data), metadata != null ? toBytes(metadata) : null); + } finally { + data.release(); + if (metadata != null) { + metadata.release(); + } + } + } + + public static Payload create(Payload payload) { + return create( + toBytes(payload.data()), payload.hasMetadata() ? toBytes(payload.metadata()) : null); + } + + private static byte[] toBytes(ByteBuf byteBuf) { + byte[] bytes = new byte[byteBuf.readableBytes()]; + byteBuf.markReaderIndex(); + byteBuf.readBytes(bytes); + byteBuf.resetReaderIndex(); + return bytes; + } + + @Override + public boolean hasMetadata() { + return metadata != null; + } + + @Override + public ByteBuf sliceMetadata() { + return metadata == null ? Unpooled.EMPTY_BUFFER : Unpooled.wrappedBuffer(metadata); + } + + @Override + public ByteBuf sliceData() { + return Unpooled.wrappedBuffer(data); + } + + @Override + public ByteBuffer getMetadata() { + return metadata == null ? DefaultPayload.EMPTY_BUFFER : metadata.duplicate(); + } + + @Override + public ByteBuffer getData() { + return data.duplicate(); + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public DefaultPayload retain() { + return this; + } + + @Override + public DefaultPayload retain(int increment) { + return this; + } + + @Override + public DefaultPayload touch() { + return this; + } + + @Override + public DefaultPayload touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java new file mode 100644 index 000000000..99df97d70 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/EmptyPayload.java @@ -0,0 +1,87 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; + +public class EmptyPayload implements Payload { + public static final EmptyPayload INSTANCE = new EmptyPayload(); + + private EmptyPayload() {} + + @Override + public boolean hasMetadata() { + return false; + } + + @Override + public ByteBuf sliceMetadata() { + return Unpooled.EMPTY_BUFFER; + } + + @Override + public ByteBuf sliceData() { + return Unpooled.EMPTY_BUFFER; + } + + @Override + public ByteBuf data() { + return sliceData(); + } + + @Override + public ByteBuf metadata() { + return sliceMetadata(); + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public EmptyPayload retain() { + return this; + } + + @Override + public EmptyPayload retain(int increment) { + return this; + } + + @Override + public EmptyPayload touch() { + return this; + } + + @Override + public EmptyPayload touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java new file mode 100644 index 000000000..3ff720447 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/NumberUtils.java @@ -0,0 +1,164 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.netty.buffer.ByteBuf; +import java.util.Objects; + +public final class NumberUtils { + + /** The size of a medium in {@code byte}s. */ + public static final int MEDIUM_BYTES = 3; + + private static final int UNSIGNED_BYTE_SIZE = 8; + + private static final int UNSIGNED_BYTE_MAX_VALUE = (1 << UNSIGNED_BYTE_SIZE) - 1; + + private static final int UNSIGNED_MEDIUM_SIZE = 24; + + private static final int UNSIGNED_MEDIUM_MAX_VALUE = (1 << UNSIGNED_MEDIUM_SIZE) - 1; + + private static final int UNSIGNED_SHORT_SIZE = 16; + + private static final int UNSIGNED_SHORT_MAX_VALUE = (1 << UNSIGNED_SHORT_SIZE) - 1; + + private NumberUtils() {} + + /** + * Requires that an {@code int} is greater than or equal to zero. + * + * @param i the {@code int} to test + * @param message detail message to be used in the event that a {@link IllegalArgumentException} + * is thrown + * @return the {@code int} if greater than or equal to zero + * @throws IllegalArgumentException if {@code i} is less than zero + */ + public static int requireNonNegative(int i, String message) { + Objects.requireNonNull(message, "message must not be null"); + + if (i < 0) { + throw new IllegalArgumentException(message); + } + + return i; + } + + /** + * Requires that a {@code long} is greater than zero. + * + * @param l the {@code long} to test + * @param message detail message to be used in the event that a {@link IllegalArgumentException} + * is thrown + * @return the {@code long} if greater than zero + * @throws IllegalArgumentException if {@code l} is less than or equal to zero + */ + public static long requirePositive(long l, String message) { + Objects.requireNonNull(message, "message must not be null"); + + if (l <= 0) { + throw new IllegalArgumentException(message); + } + + return l; + } + + /** + * Requires that an {@code int} is greater than zero. + * + * @param i the {@code int} to test + * @param message detail message to be used in the event that a {@link IllegalArgumentException} + * is thrown + * @return the {@code int} if greater than zero + * @throws IllegalArgumentException if {@code i} is less than or equal to zero + */ + public static int requirePositive(int i, String message) { + Objects.requireNonNull(message, "message must not be null"); + + if (i <= 0) { + throw new IllegalArgumentException(message); + } + + return i; + } + + /** + * Requires that an {@code int} can be represented as an unsigned {@code byte}. + * + * @param i the {@code int} to test + * @return the {@code int} if it can be represented as an unsigned {@code byte} + * @throws IllegalArgumentException if {@code i} cannot be represented as an unsigned {@code byte} + */ + public static int requireUnsignedByte(int i) { + if (i > UNSIGNED_BYTE_MAX_VALUE) { + throw new IllegalArgumentException( + String.format("%d is larger than %d bits", i, UNSIGNED_BYTE_SIZE)); + } + + return i; + } + + /** + * Requires that an {@code int} can be represented as an unsigned {@code medium}. + * + * @param i the {@code int} to test + * @return the {@code int} if it can be represented as an unsigned {@code medium} + * @throws IllegalArgumentException if {@code i} cannot be represented as an unsigned {@code + * medium} + */ + public static int requireUnsignedMedium(int i) { + if (i > UNSIGNED_MEDIUM_MAX_VALUE) { + throw new IllegalArgumentException( + String.format("%d is larger than %d bits", i, UNSIGNED_MEDIUM_SIZE)); + } + + return i; + } + + /** + * Requires that an {@code int} can be represented as an unsigned {@code short}. + * + * @param i the {@code int} to test + * @return the {@code int} if it can be represented as an unsigned {@code short} + * @throws IllegalArgumentException if {@code i} cannot be represented as an unsigned {@code + * short} + */ + public static int requireUnsignedShort(int i) { + if (i > UNSIGNED_SHORT_MAX_VALUE) { + throw new IllegalArgumentException( + String.format("%d is larger than %d bits", i, UNSIGNED_SHORT_SIZE)); + } + + return i; + } + + /** + * Encode an unsigned medium integer on 3 bytes / 24 bits. This can be decoded directly by the + * {@link ByteBuf#readUnsignedMedium()} method. + * + * @param byteBuf the {@link ByteBuf} into which to write the bits + * @param i the medium integer to encode + * @see #requireUnsignedMedium(int) + */ + public static void encodeUnsignedMedium(ByteBuf byteBuf, int i) { + requireUnsignedMedium(i); + // Write each byte separately in reverse order, this mean we can write 1 << 23 without + // overflowing. + byteBuf.writeByte(i >> 16); + byteBuf.writeByte(i >> 8); + byteBuf.writeByte(i); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/RSocketProxy.java b/rsocket-core/src/main/java/io/rsocket/util/RSocketProxy.java new file mode 100644 index 000000000..518b727c1 --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/RSocketProxy.java @@ -0,0 +1,77 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** Wrapper/Proxy for a RSocket. This is useful when we want to override a specific method. */ +public class RSocketProxy implements RSocket { + protected final RSocket source; + + public RSocketProxy(RSocket source) { + this.source = source; + } + + @Override + public Mono fireAndForget(Payload payload) { + return source.fireAndForget(payload); + } + + @Override + public Mono requestResponse(Payload payload) { + return source.requestResponse(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return source.requestStream(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return source.requestChannel(payloads); + } + + @Override + public Mono metadataPush(Payload payload) { + return source.metadataPush(payload); + } + + @Override + public double availability() { + return source.availability(); + } + + @Override + public void dispose() { + source.dispose(); + } + + @Override + public boolean isDisposed() { + return source.isDisposed(); + } + + @Override + public Mono onClose() { + return source.onClose(); + } +} diff --git a/rsocket-core/src/main/java/io/rsocket/util/package-info.java b/rsocket-core/src/main/java/io/rsocket/util/package-info.java new file mode 100644 index 000000000..2fac3327f --- /dev/null +++ b/rsocket-core/src/main/java/io/rsocket/util/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Shared utility classes and {@link io.rsocket.Payload} implementations. */ +@NonNullApi +package io.rsocket.util; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json new file mode 100644 index 000000000..0a3844451 --- /dev/null +++ b/rsocket-core/src/main/resources/META-INF/native-image/io.rsocket/rsocket-core/reflect-config.json @@ -0,0 +1,130 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueConsumerNodeRef", + "fields": [ + { + "name": "consumerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef" + }, + "name": "io.rsocket.internal.jctools.queues.BaseLinkedQueueProducerNodeRef", + "fields": [ + { + "name": "producerNode" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueColdProducerFields", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueConsumerFields", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields" + }, + "name": "io.rsocket.internal.jctools.queues.BaseMpscLinkedArrayQueueProducerFields", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.LinkedQueueNode" + }, + "name": "io.rsocket.internal.jctools.queues.LinkedQueueNode", + "fields": [ + { + "name": "next" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueConsumerIndexField", + "fields": [ + { + "name": "consumerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerIndexField", + "fields": [ + { + "name": "producerIndex" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField" + }, + "name": "io.rsocket.internal.jctools.queues.MpscArrayQueueProducerLimitField", + "fields": [ + { + "name": "producerLimit" + } + ] + }, + { + "condition": { + "typeReachable": "io.rsocket.internal.jctools.queues.UnsafeAccess" + }, + "name": "sun.misc.Unsafe", + "fields": [ + { + "name": "theUnsafe" + } + ], + "queriedMethods": [ + { + "name": "getAndAddLong", + "parameterTypes": [ + "java.lang.Object", + "long", + "long" + ] + }, + { + "name": "getAndSetObject", + "parameterTypes": [ + "java.lang.Object", + "long", + "java.lang.Object" + ] + } + ] + } +] \ No newline at end of file diff --git a/rsocket-core/src/test/java/io/rsocket/FrameAssert.java b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java new file mode 100644 index 000000000..b5b1e2ec9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/FrameAssert.java @@ -0,0 +1,336 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBe.shouldBe; +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; +import static org.assertj.core.error.ShouldNotHave.shouldNotHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.*; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class FrameAssert extends AbstractAssert { + public static FrameAssert assertThat(@Nullable ByteBuf frame) { + return new FrameAssert(frame); + } + + private final Failures failures = Failures.instance(); + + public FrameAssert(@Nullable ByteBuf frame) { + super(frame, FrameAssert.class); + } + + public FrameAssert hasMetadata() { + assertValid(); + + if (!FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public FrameAssert hasNoMetadata() { + assertValid(); + + if (FrameHeaderCodec.hasMetadata(actual)) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public FrameAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public FrameAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public FrameAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + final FrameType frameType = FrameHeaderCodec.frameType(actual); + ByteBuf content; + if (frameType == FrameType.METADATA_PUSH) { + content = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.metadata(actual); + } else { + content = PayloadFrameCodec.metadata(actual); + } + + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public FrameAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public FrameAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public FrameAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content; + final FrameType frameType = FrameHeaderCodec.frameType(actual); + if (!frameType.canHaveData()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have data content but frame type %n<%s> does not support data content", + actual, frameType)); + } else if (frameType.hasInitialRequestN()) { + content = RequestStreamFrameCodec.data(actual); + } else if (frameType == FrameType.ERROR) { + content = ErrorFrameCodec.data(actual); + } else { + content = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public FrameAssert hasFragmentsFollow() { + return hasFollows(true); + } + + public FrameAssert hasNoFragmentsFollow() { + return hasFollows(false); + } + + public FrameAssert hasFollows(boolean hasFollows) { + assertValid(); + + if (FrameHeaderCodec.hasFollows(actual) != hasFollows) { + throw failures.failure( + info, + hasFollows + ? shouldHave(actual, new Condition<>("follows fragment present")) + : shouldNotHave(actual, new Condition<>("follows fragment present"))); + } + + return this; + } + + public FrameAssert typeOf(FrameType frameType) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + if (currentFrameType != frameType) { + throw failures.failure( + info, shouldBe(currentFrameType, new Condition<>("frame of type [" + frameType + "]"))); + } + + return this; + } + + public FrameAssert hasStreamId(int streamId) { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId != streamId) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting streamId:%n<%s>%n to be equal %n<%s>", currentStreamId, streamId)); + } + + return this; + } + + public FrameAssert hasStreamIdZero() { + return hasStreamId(0); + } + + public FrameAssert hasClientSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId % 2 != 1) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting Client Side StreamId %nbut was " + + (currentStreamId == 0 ? "Stream Id 0" : "Server Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasServerSideStreamId() { + assertValid(); + + final int currentStreamId = FrameHeaderCodec.streamId(actual); + if (currentStreamId == 0 || currentStreamId % 2 != 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n Server Side Stream Id %nbut was %n " + + (currentStreamId == 0 ? "Stream Id 0" : "Client Side Stream Id"))); + } + + return this; + } + + public FrameAssert hasPayloadSize(int payloadLength) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + + final int currentFrameLength = + actual.readableBytes() + - FrameHeaderCodec.size() + - (FrameHeaderCodec.hasMetadata(actual) && currentFrameType.canHaveData() ? 3 : 0) + - (currentFrameType.hasInitialRequestN() ? Integer.BYTES : 0); + if (currentFrameLength != payloadLength) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting %n<%s> %nframe payload size to be equal to %n<%s> %nbut was %n<%s>", + actual, payloadLength, currentFrameLength)); + } + + return this; + } + + public FrameAssert hasRequestN(int n) { + assertValid(); + + final FrameType currentFrameType = FrameHeaderCodec.frameType(actual); + long requestN; + if (currentFrameType.hasInitialRequestN()) { + requestN = RequestStreamFrameCodec.initialRequestN(actual); + } else if (currentFrameType == FrameType.REQUEST_N) { + requestN = RequestNFrameCodec.requestN(actual); + } else { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have requestN but frame type %n<%s> does not support requestN", + actual, currentFrameType)); + } + + if ((requestN > Integer.MAX_VALUE ? Integer.MAX_VALUE : requestN) != n) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have %nrequestN(<%s>) but got %nrequestN(<%s>)", + actual, n, requestN)); + } + + return this; + } + + public FrameAssert hasPayload(Payload expectedPayload) { + assertValid(); + + List failedExpectation = new ArrayList<>(); + FrameType frameType = FrameHeaderCodec.frameType(actual); + boolean hasMetadata = FrameHeaderCodec.hasMetadata(actual); + if (expectedPayload.hasMetadata() != hasMetadata) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), hasMetadata)); + } else if (hasMetadata) { + ByteBuf metadataContent; + if (frameType == FrameType.METADATA_PUSH) { + metadataContent = MetadataPushFrameCodec.metadata(actual); + } else if (frameType.hasInitialRequestN()) { + metadataContent = RequestStreamFrameCodec.metadata(actual); + } else { + metadataContent = PayloadFrameCodec.metadata(actual); + } + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), metadataContent)) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), metadataContent)); + } + } + + ByteBuf dataContent; + if (!frameType.canHaveData() && expectedPayload.sliceData().readableBytes() > 0) { + failedExpectation.add( + String.format( + "data(%s) but frame type %n<%s> does not support data", actual, frameType)); + } else { + if (frameType.hasInitialRequestN()) { + dataContent = RequestStreamFrameCodec.data(actual); + } else { + dataContent = PayloadFrameCodec.data(actual); + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), dataContent)) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", expectedPayload.sliceData(), dataContent)); + } + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given payload but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + public void hasNoLeaks() { + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + + try { + FrameHeaderCodec.frameType(actual); + } catch (Throwable t) { + throw failures.failure( + info, shouldBe(actual, new Condition<>("a valid frame, but got exception [" + t + "]"))); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/FrameTest.java b/rsocket-core/src/test/java/io/rsocket/FrameTest.java new file mode 100644 index 000000000..82af5f53c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/FrameTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket; + +public class FrameTest { + /*@Test + public void testFrameToString() { + final io.rsocket.Frame requestFrame = + io.rsocket.Frame.Request.from( + 1, FrameType.REQUEST_RESPONSE, DefaultPayload.create("streaming in -> 0"), 1); + assertEquals( + "Frame => Stream ID: 1 Type: REQUEST_RESPONSE Payload: data: \"streaming in -> 0\" ", + requestFrame.toString()); + } + + @Test + public void testFrameWithMetadataToString() { + final io.rsocket.Frame requestFrame = + io.rsocket.Frame.Request.from( + 1, + FrameType.REQUEST_RESPONSE, + DefaultPayload.create("streaming in -> 0", "metadata"), + 1); + assertEquals( + "Frame => Stream ID: 1 Type: REQUEST_RESPONSE Payload: metadata: \"metadata\" data: \"streaming in -> 0\" ", + requestFrame.toString()); + } + + @Test + public void testPayload() { + io.rsocket.Frame frame = + io.rsocket.Frame.PayloadFrame.from( + 1, + FrameType.NEXT_COMPLETE, + DefaultPayload.create("Hello"), + FrameHeaderFlyweight.FLAGS_C); + frame.toString(); + }*/ +} diff --git a/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java new file mode 100755 index 000000000..847f24722 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/PayloadAssert.java @@ -0,0 +1,180 @@ +package io.rsocket; + +import static org.assertj.core.error.ShouldBeEqual.shouldBeEqual; +import static org.assertj.core.error.ShouldHave.shouldHave; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.frame.ByteBufRepresentation; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.api.Condition; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.internal.Failures; +import org.assertj.core.internal.Objects; +import reactor.util.annotation.Nullable; + +public class PayloadAssert extends AbstractAssert { + + public static PayloadAssert assertThat(@Nullable Payload payload) { + return new PayloadAssert(payload); + } + + private final Failures failures = Failures.instance(); + + public PayloadAssert(@Nullable Payload payload) { + super(payload, PayloadAssert.class); + } + + public PayloadAssert hasMetadata() { + assertValid(); + + if (!actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata present"))); + } + + return this; + } + + public PayloadAssert hasNoMetadata() { + assertValid(); + + if (actual.hasMetadata()) { + throw failures.failure(info, shouldHave(actual, new Condition<>("metadata absent"))); + } + + return this; + } + + public PayloadAssert hasMetadata(String metadata, Charset charset) { + return hasMetadata(metadata.getBytes(charset)); + } + + public PayloadAssert hasMetadata(String metadataUtf8) { + return hasMetadata(metadataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasMetadata(byte[] metadata) { + return hasMetadata(Unpooled.wrappedBuffer(metadata)); + } + + public PayloadAssert hasMetadata(ByteBuf metadata) { + hasMetadata(); + + ByteBuf content = actual.sliceMetadata(); + if (!ByteBufUtil.equals(content, metadata)) { + throw failures.failure(info, shouldBeEqual(content, metadata, new ByteBufRepresentation())); + } + + return this; + } + + public PayloadAssert hasData(String dataUtf8) { + return hasData(dataUtf8, CharsetUtil.UTF_8); + } + + public PayloadAssert hasData(String data, Charset charset) { + return hasData(data.getBytes(charset)); + } + + public PayloadAssert hasData(byte[] data) { + return hasData(Unpooled.wrappedBuffer(data)); + } + + public PayloadAssert hasData(ByteBuf data) { + assertValid(); + + ByteBuf content = actual.sliceData(); + if (!ByteBufUtil.equals(content, data)) { + throw failures.failure(info, shouldBeEqual(content, data, new ByteBufRepresentation())); + } + + return this; + } + + public void hasNoLeaks() { + if (!(actual instanceof DefaultPayload)) { + if (actual.refCnt() == 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was already released", + actual, actual.refCnt())); + } + if (!actual.release() || actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) after release but " + + "actual was " + + "%n", + actual, actual.refCnt())); + } + } + } + + public void isReleased() { + if (actual.refCnt() > 0) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting: %n<%s> %nto have refCnt(0) but " + "actual was " + "%n", + actual, actual.refCnt())); + } + } + + @Override + public PayloadAssert isEqualTo(Object expected) { + if (expected instanceof Payload) { + if (expected == actual) { + return this; + } + + Payload expectedPayload = (Payload) expected; + List failedExpectation = new ArrayList<>(); + if (expectedPayload.hasMetadata() != actual.hasMetadata()) { + failedExpectation.add( + String.format( + "hasMetadata(%s) but actual was hasMetadata(%s)%n", + expectedPayload.hasMetadata(), actual.hasMetadata())); + } else { + if (!ByteBufUtil.equals(expectedPayload.sliceMetadata(), actual.sliceMetadata())) { + failedExpectation.add( + String.format( + "metadata(%s) but actual was metadata(%s)%n", + expectedPayload.sliceMetadata(), actual.sliceMetadata())); + } + } + + if (!ByteBufUtil.equals(expectedPayload.sliceData(), actual.sliceData())) { + failedExpectation.add( + String.format( + "data(%s) but actual was data(%s)%n", + expectedPayload.sliceData(), actual.sliceData())); + } + + if (!failedExpectation.isEmpty()) { + throw failures.failure( + info, + new BasicErrorMessageFactory( + "%nExpecting be equal to the given one but the following differences were found" + + " %s", + failedExpectation)); + } + + return this; + } + + return super.isEqualTo(expected); + } + + private void assertValid() { + Objects.instance().assertNotNull(info, actual); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java new file mode 100644 index 000000000..d30f1415e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/RaceTestConstants.java @@ -0,0 +1,6 @@ +package io.rsocket; + +public class RaceTestConstants { + public static final int REPEATS = + Integer.parseInt(System.getProperty("rsocket.test.race.repeats", "1000")); +} diff --git a/rsocket-core/src/test/java/io/rsocket/TestScheduler.java b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java new file mode 100644 index 000000000..7bc98d45d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/TestScheduler.java @@ -0,0 +1,80 @@ +package io.rsocket; + +import java.util.Queue; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.scheduler.Scheduler; +import reactor.util.concurrent.Queues; + +/** + * This is an implementation of scheduler which allows task execution on the caller thread or + * scheduling it for thread which are currently working (with "work stealing" behaviour) + */ +public final class TestScheduler implements Scheduler { + + public static final Scheduler INSTANCE = new TestScheduler(); + + volatile int wip; + static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(TestScheduler.class, "wip"); + + final Worker sharedWorker = new TestWorker(this); + final Queue tasks = Queues.unboundedMultiproducer().get(); + + private TestScheduler() {} + + @Override + public Disposable schedule(Runnable task) { + tasks.offer(task); + if (WIP.getAndIncrement(this) != 0) { + return Disposables.never(); + } + + int missed = 1; + + for (; ; ) { + for (; ; ) { + Runnable runnable = tasks.poll(); + + if (runnable == null) { + break; + } + + try { + runnable.run(); + } catch (Throwable t) { + Exceptions.throwIfFatal(t); + } + } + + missed = WIP.addAndGet(this, -missed); + if (missed == 0) { + return Disposables.never(); + } + } + } + + @Override + public Worker createWorker() { + return sharedWorker; + } + + static class TestWorker implements Worker { + + final TestScheduler parent; + + TestWorker(TestScheduler parent) { + this.parent = parent; + } + + @Override + public Disposable schedule(Runnable task) { + return parent.schedule(task); + } + + @Override + public void dispose() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..1db708ab5 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/buffer/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,294 @@ +package io.rsocket.buffer; + +import static java.util.concurrent.locks.LockSupport.parkNanos; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + final Duration awaitZeroRefCntDuration; + + final String tag; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { + this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java new file mode 100644 index 000000000..310e15b3e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/AbstractSocketRule.java @@ -0,0 +1,76 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import java.time.Duration; +import org.reactivestreams.Subscriber; + +public abstract class AbstractSocketRule { + + protected TestDuplexConnection connection; + protected Subscriber connectSub; + protected T socket; + protected LeaksTrackingByteBufAllocator allocator; + protected int maxFrameLength = FRAME_LENGTH_MASK; + protected int maxInboundPayloadSize = Integer.MAX_VALUE; + + public void init() { + allocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(5), ""); + connectSub = TestSubscriber.create(); + doInit(); + } + + protected void doInit() { + if (connection != null) { + connection.dispose(); + } + if (socket != null) { + socket.dispose(); + } + connection = new TestDuplexConnection(allocator); + socket = newRSocket(); + } + + public void setMaxInboundPayloadSize(int maxInboundPayloadSize) { + this.maxInboundPayloadSize = maxInboundPayloadSize; + doInit(); + } + + public void setMaxFrameLength(int maxFrameLength) { + this.maxFrameLength = maxFrameLength; + doInit(); + } + + protected abstract T newRSocket(); + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public void assertHasNoLeaks() { + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java new file mode 100644 index 000000000..195df9434 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ClientServerInputMultiplexerTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.MetadataPushFrameCodec; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.test.util.TestDuplexConnection; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class ClientServerInputMultiplexerTest { + private TestDuplexConnection source; + private ClientServerInputMultiplexer clientMultiplexer; + private LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private ClientServerInputMultiplexer serverMultiplexer; + + @BeforeEach + public void setup() { + source = new TestDuplexConnection(allocator); + clientMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), true); + serverMultiplexer = + new ClientServerInputMultiplexer(source, new InitializingInterceptorRegistry(), false); + } + + @Test + public void clientSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + clientMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + clientMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isOne(); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(4); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(5); + assertThat(serverFrames.get()).isEqualTo(2); + } + + @Test + public void serverSplits() { + AtomicInteger clientFrames = new AtomicInteger(); + AtomicInteger serverFrames = new AtomicInteger(); + + serverMultiplexer + .asClientConnection() + .receive() + .doOnNext( + f -> { + clientFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + serverMultiplexer + .asServerConnection() + .receive() + .doOnNext( + f -> { + serverFrames.incrementAndGet(); + f.release(); + }) + .subscribe(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(1); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(errorFrame(1).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isZero(); + + source.addToReceivedBuffer(leaseFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isOne(); + + source.addToReceivedBuffer(keepAliveFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(2); + + source.addToReceivedBuffer(errorFrame(2).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(3); + + source.addToReceivedBuffer(errorFrame(0).retain()); + assertThat(clientFrames.get()).isEqualTo(2); + assertThat(serverFrames.get()).isEqualTo(4); + + source.addToReceivedBuffer(metadataPushFrame().retain()); + assertThat(clientFrames.get()).isEqualTo(3); + assertThat(serverFrames.get()).isEqualTo(4); + } + + private ByteBuf leaseFrame() { + return LeaseFrameCodec.encode(allocator, 1_000, 1, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf errorFrame(int i) { + return ErrorFrameCodec.encode(allocator, i, new Exception()); + } + + private ByteBuf keepAliveFrame() { + return KeepAliveFrameCodec.encode(allocator, false, 0, Unpooled.EMPTY_BUFFER); + } + + private ByteBuf metadataPushFrame() { + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java new file mode 100644 index 000000000..8eb5dee09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ConnectionSetupPayloadTest.java @@ -0,0 +1,90 @@ +package io.rsocket.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.util.DefaultPayload; +import org.junit.jupiter.api.Test; + +class ConnectionSetupPayloadTest { + private static final int KEEP_ALIVE_INTERVAL = 5; + private static final int KEEP_ALIVE_MAX_LIFETIME = 500; + private static final String METADATA_TYPE = "metadata_type"; + private static final String DATA_TYPE = "data_type"; + + @Test + void testSetupPayloadWithDataMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {2, 1, 0}); + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = true; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertTrue(setupPayload.willClientHonorLease()); + assertEquals(KEEP_ALIVE_INTERVAL, setupPayload.keepAliveInterval()); + assertEquals(KEEP_ALIVE_MAX_LIFETIME, setupPayload.keepAliveMaxLifetime()); + assertEquals(METADATA_TYPE, SetupFrameCodec.metadataMimeType(frame)); + assertEquals(DATA_TYPE, SetupFrameCodec.dataMimeType(frame)); + assertTrue(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(payload.metadata(), setupPayload.metadata()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + @Test + void testSetupPayloadWithNoMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = null; + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = false; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertFalse(setupPayload.willClientHonorLease()); + assertFalse(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(0, setupPayload.metadata().readableBytes()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + @Test + void testSetupPayloadWithEmptyMetadata() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf metadata = Unpooled.EMPTY_BUFFER; + Payload payload = DefaultPayload.create(data, metadata); + boolean leaseEnabled = false; + + ByteBuf frame = encodeSetupFrame(leaseEnabled, payload); + ConnectionSetupPayload setupPayload = new DefaultConnectionSetupPayload(frame); + + assertFalse(setupPayload.willClientHonorLease()); + assertTrue(setupPayload.hasMetadata()); + assertNotNull(setupPayload.metadata()); + assertEquals(0, setupPayload.metadata().readableBytes()); + assertEquals(payload.data(), setupPayload.data()); + frame.release(); + } + + private static ByteBuf encodeSetupFrame(boolean leaseEnabled, Payload setupPayload) { + return SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + leaseEnabled, + KEEP_ALIVE_INTERVAL, + KEEP_ALIVE_MAX_LIFETIME, + Unpooled.EMPTY_BUFFER, + METADATA_TYPE, + DATA_TYPE, + setupPayload); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java new file mode 100644 index 000000000..84576e6ce --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/DefaultRSocketClientTests.java @@ -0,0 +1,760 @@ +package io.rsocket.core; +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.RSocketProxy; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Map; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; +import reactor.util.context.ContextView; +import reactor.util.retry.Retry; + +public class DefaultRSocketClientTests { + + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); + rule = new ClientSocketRule(); + rule.init(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + rule.allocator.assertHasNoLeaks(); + } + + @Test + @SuppressWarnings("unchecked") + void discardElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = Mockito.mock(ReferenceCounted.class); + Mockito.when(referenceCounted.release()).thenReturn(true); + Mockito.when(referenceCounted.refCnt()).thenReturn(1); + + Consumer discardElementsConsumer = DefaultRSocketClient.DISCARD_ELEMENTS_CONSUMER; + discardElementsConsumer.accept(referenceCounted); + + Mockito.verify(referenceCounted).release(); + } + + static Stream interactions() { + return Stream.of( + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.fireAndForget(Mono.fromDirect(payload)), + FrameType.REQUEST_FNF), + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.requestResponse(Mono.fromDirect(payload)), + FrameType.REQUEST_RESPONSE), + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.requestStream(Mono.fromDirect(payload)), + FrameType.REQUEST_STREAM), + Arguments.of( + (BiFunction, Publisher>) + RSocketClient::requestChannel, + FrameType.REQUEST_CHANNEL), + Arguments.of( + (BiFunction, Publisher>) + (client, payload) -> client.metadataPush(Mono.fromDirect(payload)), + FrameType.METADATA_PUSH)); + } + + @ParameterizedTest + @MethodSource("interactions") + public void shouldSentFrameOnResolution( + BiFunction, Publisher> request, FrameType requestType) { + Payload payload = ByteBufPayload.create("test", "testMetadata"); + TestPublisher testPublisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + Publisher publisher = request.apply(rule.client, testPublisher); + + StepVerifier.create(publisher) + .expectSubscription() + .then(() -> Assertions.assertThat(rule.connection.getSent()).isEmpty()) + .then( + () -> { + if (requestType != FrameType.REQUEST_CHANNEL) { + testPublisher.next(payload); + } + }) + .then(() -> rule.delayer.run()) + .then( + () -> { + if (requestType == FrameType.REQUEST_CHANNEL) { + testPublisher.next(payload); + } + }) + .then(testPublisher::complete) + .then( + () -> { + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(rule.connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + + Assertions.assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + } + }) + .then( + () -> { + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeComplete(rule.allocator, 1)); + } + }) + .expectComplete() + .verify(Duration.ofMillis(1000)); + + rule.allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfOnNextAndCancel( + BiFunction, Publisher> request, FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + rule.init(); + Payload payload = ByteBufPayload.create("test", "testMetadata"); + TestPublisher testPublisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + Publisher publisher = request.apply(rule.client, testPublisher); + publisher.subscribe(assertSubscriber); + + testPublisher.assertWasNotRequested(); + + assertSubscriber.request(1); + + testPublisher.assertWasRequested(); + testPublisher.assertMaxRequested(1); + testPublisher.assertMinRequested(1); + + RaceTestUtils.race( + () -> { + testPublisher.next(payload); + rule.delayer.run(); + }, + assertSubscriber::cancel); + + Collection sent = rule.connection.getSent(); + if (sent.size() == 1) { + Assertions.assertThat(sent) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } else if (sent.size() == 2) { + Assertions.assertThat(sent) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + Assertions.assertThat(sent) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.CANCEL)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent).isEmpty(); + } + + rule.allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldHaveNoLeaksOnPayloadInCaseOfRacingOfRequestAndCancel( + BiFunction, Publisher> request, FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + rule.init(); + ByteBuf dataBuffer = rule.allocator.buffer(); + dataBuffer.writeCharSequence("test", CharsetUtil.UTF_8); + + ByteBuf metadataBuffer = rule.allocator.buffer(); + metadataBuffer.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + + Payload payload = ByteBufPayload.create(dataBuffer, metadataBuffer); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + Publisher publisher = request.apply(rule.client, Mono.just(payload)); + publisher.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + assertSubscriber.request(1); + rule.delayer.run(); + }, + assertSubscriber::cancel); + + Collection sent = rule.connection.getSent(); + if (sent.size() == 1) { + Assertions.assertThat(sent) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } else if (sent.size() == 2) { + Assertions.assertThat(sent) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + Assertions.assertThat(sent) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.CANCEL)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent).isEmpty(); + } + + rule.allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldPropagateDownstreamContext( + BiFunction, Publisher> request, FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(FrameType.REQUEST_CHANNEL); + + ByteBuf dataBuffer = rule.allocator.buffer(); + dataBuffer.writeCharSequence("test", CharsetUtil.UTF_8); + + ByteBuf metadataBuffer = rule.allocator.buffer(); + metadataBuffer.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + + Payload payload = ByteBufPayload.create(dataBuffer, metadataBuffer); + AssertSubscriber assertSubscriber = new AssertSubscriber(Context.of("test", "test")); + + ContextView[] receivedContext = new Context[1]; + Publisher publisher = + request.apply( + rule.client, + Mono.just(payload) + .mergeWith( + Mono.deferContextual( + c -> { + receivedContext[0] = c; + return Mono.empty(); + }) + .then(Mono.empty()))); + publisher.subscribe(assertSubscriber); + + rule.delayer.run(); + + Collection sent = rule.connection.getSent(); + if (sent.size() == 1) { + Assertions.assertThat(sent) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } else if (sent.size() == 2) { + Assertions.assertThat(sent) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + Assertions.assertThat(sent) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.CANCEL)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent).isEmpty(); + } + + Assertions.assertThat(receivedContext) + .hasSize(1) + .allSatisfy( + c -> + Assertions.assertThat( + c.stream() + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue))) + .containsKeys("test", DefaultRSocketClient.ON_DISCARD_KEY)); + + rule.allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"unchecked", "rawtypes"}) + public void shouldSupportMultiSubscriptionOnTheSameInteractionPublisher( + BiFunction, Publisher> request, FrameType requestType) { + AtomicBoolean once1 = new AtomicBoolean(); + AtomicBoolean once2 = new AtomicBoolean(); + Mono source = + Mono.fromCallable( + () -> { + if (!once1.getAndSet(true)) { + throw new IllegalStateException("test"); + } + return ByteBufPayload.create("test", "testMetadata"); + }) + .doFinally( + st -> { + rule.delayer.run(); + if (requestType != FrameType.METADATA_PUSH + && requestType != FrameType.REQUEST_FNF) { + if (st != SignalType.ON_ERROR) { + if (!once2.getAndSet(true)) { + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode( + rule.allocator, 1, new IllegalStateException("test"))); + } else { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeComplete(rule.allocator, 3)); + } + } + } + }); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + Publisher publisher = request.apply(rule.client, source); + if (publisher instanceof Mono) { + ((Mono) publisher) + .retryWhen(Retry.backoff(3, Duration.ofMillis(100))) + .subscribe(assertSubscriber); + } else { + ((Flux) publisher) + .retryWhen(Retry.backoff(3, Duration.ofMillis(100))) + .subscribe(assertSubscriber); + } + + assertSubscriber.request(1); + + if (requestType == FrameType.REQUEST_CHANNEL) { + rule.delayer.run(); + } + + assertSubscriber.await(Duration.ofSeconds(10)).assertComplete(); + + if (requestType == FrameType.REQUEST_CHANNEL) { + ArrayList sent = new ArrayList<>(rule.connection.getSent()); + Assertions.assertThat(sent).hasSize(4); + for (int i = 0; i < sent.size(); i++) { + if (i % 2 == 0) { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(sent.get(i)) + .matches(bb -> FrameHeaderCodec.frameType(bb).equals(FrameType.COMPLETE)) + .matches(ReferenceCounted::release); + } + } + } else { + Collection sent = rule.connection.getSent(); + Assertions.assertThat(sent) + .hasSize( + requestType == FrameType.REQUEST_FNF || requestType == FrameType.METADATA_PUSH + ? 1 + : 2) + .allMatch(bb -> FrameHeaderCodec.frameType(bb).equals(requestType)) + .allMatch(ReferenceCounted::release); + } + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldBeAbleToResolveOriginalSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + rule.client.source().subscribe(assertSubscriber); + + assertSubscriber.assertNotTerminated(); + + rule.delayer.run(); + + assertSubscriber.request(1); + + assertSubscriber.assertTerminated().assertValueCount(1); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.source().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertValueCount(1); + + Assertions.assertThat(assertSubscriber1.values()).isEqualTo(assertSubscriber.values()); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldDisposeOriginalSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.source().subscribe(assertSubscriber1); + + assertSubscriber1 + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Disposed"); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldReceiveOnCloseNotificationOnDisposeOriginalSource() { + Sinks.Empty onCloseDelayer = Sinks.empty(); + ClientSocketRule rule = + new ClientSocketRule() { + @Override + protected RSocket newRSocket() { + return new RSocketProxy(super.newRSocket()) { + @Override + public Mono onClose() { + return super.onClose().and(onCloseDelayer.asMono()); + } + }; + } + }; + rule.init(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber onCloseSubscriber = AssertSubscriber.create(); + + rule.client.onClose().subscribe(onCloseSubscriber); + onCloseSubscriber.assertNotTerminated(); + + onCloseDelayer.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldResolveOnStartSource() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.delayer.run(); + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldNotStartIfAlreadyDisposed() { + Assertions.assertThat(rule.client.connect()).isTrue(); + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.delayer.run(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + + assertSubscriber1.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldBeRestartedIfSourceWasClosed() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + AssertSubscriber terminateSubscriber = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber); + rule.client.onClose().subscribe(terminateSubscriber); + + rule.delayer.run(); + + assertSubscriber.assertTerminated().assertValueCount(1); + + rule.socket.dispose(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + terminateSubscriber.assertNotTerminated(); + Assertions.assertThat(rule.client.isDisposed()).isFalse(); + + rule.connection = new TestDuplexConnection(rule.allocator); + rule.socket = rule.newRSocket(); + rule.producer = Sinks.one(); + + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(); + + Assertions.assertThat(rule.client.connect()).isTrue(); + rule.client.source().subscribe(assertSubscriber2); + + rule.delayer.run(); + + assertSubscriber2.assertTerminated().assertValueCount(1); + + rule.client.dispose(); + + terminateSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(rule.client.connect()).isFalse(); + + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + rule.allocator.assertHasNoLeaks(); + } + + @Test + public void shouldDisposeOriginalSourceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + + rule.init(); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.client.source().subscribe(assertSubscriber); + + RaceTestUtils.race(rule.delayer, () -> rule.client.dispose()); + + assertSubscriber.assertTerminated(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.source().subscribe(assertSubscriber1); + + assertSubscriber1 + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Disposed"); + + ByteBuf buf; + while ((buf = rule.connection.pollFrame()) != null) { + FrameAssert.assertThat(buf).hasStreamIdZero().hasData("Disposed").hasNoLeaks(); + } + + rule.allocator.assertHasNoLeaks(); + } + } + + @Test + public void shouldStartOriginalSourceOnceIfRacing() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule rule = new ClientSocketRule(); + + rule.init(); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + RaceTestUtils.race( + () -> rule.client.source().subscribe(assertSubscriber), () -> rule.client.connect()); + + Assertions.assertThat(rule.producer.currentSubscriberCount()).isOne(); + + rule.delayer.run(); + + assertSubscriber.assertTerminated(); + + rule.client.dispose(); + + Assertions.assertThat(rule.client.isDisposed()).isTrue(); + Assertions.assertThat(rule.socket.isDisposed()).isTrue(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + + rule.client.onClose().subscribe(assertSubscriber1); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + + assertSubscriber1.assertTerminated().assertComplete(); + + rule.allocator.assertHasNoLeaks(); + } + } + + public static class ClientSocketRule extends AbstractSocketRule { + + protected RSocketClient client; + protected Runnable delayer; + protected Sinks.One producer; + + protected Sinks.Empty thisClosedSink; + + @Override + protected void doInit() { + super.doInit(); + delayer = () -> producer.tryEmitValue(socket); + producer = Sinks.one(); + client = + new DefaultRSocketClient( + Mono.defer( + () -> + producer + .asMono() + .doOnCancel(() -> socket.dispose()) + .doOnDiscard(Disposable.class, Disposable::dispose))); + } + + @Override + protected RSocket newRSocket() { + this.thisClosedSink = Sinks.empty(); + return new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + maxFrameLength, + maxInboundPayloadSize, + Integer.MAX_VALUE, + Integer.MAX_VALUE, + null, + __ -> null, + null, + thisClosedSink, + thisClosedSink.asMono()); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java new file mode 100644 index 000000000..f5422a4bf --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/FireAndForgetRequesterMonoTest.java @@ -0,0 +1,448 @@ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class FireAndForgetRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameShouldBeSentOnSubscription(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Payload payload = genericPayload(activeStreams.getAllocator()); + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + final ByteBuf frame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectNothing(); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of FireAndForgetMono subscribes, and then sends all fragments as a separate + * frame immediately + */ + @ParameterizedTest + @MethodSource("frameSent") + public void frameFragmentsShouldBeSentOnSubscription( + Consumer monoConsumer) { + final int mtu = 64; + final TestRequesterResponderSupport streamManager = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + // should not add anything to map + streamManager.assertNoActiveStreams(); + stateAssert.isTerminated(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> frameSent() { + return Stream.of( + (s) -> StepVerifier.create(s).expectSubscription().expectComplete().verify(), + FireAndForgetRequesterMono::block); + } + + /** + * RefCnt validation test. Should send error if RefCnt is incorrect and frame has already been + * released Note: ONCE state should be 0 + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject(FrameType.REQUEST_FNF, new IllegalReferenceCountException("refCnt: 0")) + .expectNothing(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Check that proper payload size validation is enabled so in case payload fragmentation is + * disabled we will not send anything bigger that 16MB (see specification for MAX frame size) + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor + .expectOnReject( + FrameType.REQUEST_FNF, + new IllegalArgumentException( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK))) + .expectNothing(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that frame will not be sent if we dont have availability for that. Options: 1. RSocket + * disposed / Connection Error, so all racing on existing interactions should be terminated as + * well 2. RSocket tries to use lease and end-ups with no available leases + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RuntimeException exception = new RuntimeException("test"); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(exception, testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + final Payload payload = genericPayload(allocator); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + stateAssert.isUnsubscribed(); + streamManager.assertNoActiveStreams(); + + monoConsumer.accept(fireAndForgetRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + testRequestInterceptor.expectOnReject(FrameType.REQUEST_FNF, exception).expectNothing(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + fireAndForgetRequesterMono -> + Assertions.assertThatThrownBy(fireAndForgetRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + /** Ensures single subscription happens in case of racing */ + @Test + public void shouldSubscribeExactlyOnce1() { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport streamManager = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = streamManager.getAllocator(); + final TestDuplexConnection sender = streamManager.getDuplexConnection(); + + for (int i = 1; i < 50000; i += 2) { + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, streamManager); + final StateAssert stateAssert = + StateAssert.assertThat(FireAndForgetRequesterMono.STATE, fireAndForgetRequesterMono); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AtomicReference atomicReference = new AtomicReference<>(); + fireAndForgetRequesterMono.subscribe(null, atomicReference::set); + Throwable throwable = atomicReference.get(); + if (throwable != null) { + throw Exceptions.propagate(throwable); + } + }, + fireAndForgetRequesterMono::block)) + .matches( + t -> { + Assertions.assertThat(t) + .hasMessageContaining("FireAndForgetMono allows only a single Subscriber"); + return true; + }); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_FNF) + .hasClientSideStreamId() + .hasStreamId(i) + .hasNoLeaks(); + + stateAssert.isTerminated(); + streamManager.assertNoActiveStreams(); + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final FireAndForgetRequesterMono fireAndForgetRequesterMono = + new FireAndForgetRequesterMono(payload, testRequesterResponderSupport); + + Assertions.assertThat(Scannable.from(fireAndForgetRequesterMono).name()) + .isEqualTo("source(FireAndForgetMono)"); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java new file mode 100644 index 000000000..5be59235c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/KeepAliveTest.java @@ -0,0 +1,420 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.keepalive.KeepAliveHandler.DefaultKeepAliveHandler; +import static io.rsocket.keepalive.KeepAliveHandler.ResumableKeepAliveHandler; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ConnectionErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.resume.RSocketSession; +import io.rsocket.resume.ResumableDuplexConnection; +import io.rsocket.resume.ResumeStateHolder; +import io.rsocket.test.util.TestDuplexConnection; +import java.time.Duration; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class KeepAliveTest { + private static final int KEEP_ALIVE_INTERVAL = 100; + private static final int KEEP_ALIVE_TIMEOUT = 1000; + private static final int RESUMABLE_KEEP_ALIVE_TIMEOUT = 200; + + VirtualTimeScheduler virtualTimeScheduler; + + @BeforeEach + public void setUp() { + virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + } + + @AfterEach + public void tearDown() { + VirtualTimeScheduler.reset(); + } + + static RSocketState requester(int tickPeriod, int timeout) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); + Sinks.Empty empty = Sinks.empty(); + RSocketRequester rSocket = + new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + tickPeriod, + timeout, + new DefaultKeepAliveHandler(), + r -> null, + null, + empty, + empty.asMono()); + return new RSocketState(rSocket, allocator, connection, empty); + } + + static ResumableRSocketState resumableRequester(int tickPeriod, int timeout) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection connection = new TestDuplexConnection(allocator); + ResumableDuplexConnection resumableConnection = + new ResumableDuplexConnection( + "test", + Unpooled.EMPTY_BUFFER, + connection, + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 10_000)); + Sinks.Empty onClose = Sinks.empty(); + + RSocketRequester rSocket = + new RSocketRequester( + resumableConnection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + tickPeriod, + timeout, + new ResumableKeepAliveHandler( + resumableConnection, + Mockito.mock(RSocketSession.class), + Mockito.mock(ResumeStateHolder.class)), + __ -> null, + null, + onClose, + onClose.asMono()); + return new ResumableRSocketState(rSocket, connection, resumableConnection, onClose, allocator); + } + + @Test + void rSocketNotDisposedOnPresentKeepAlives() { + RSocketState requesterState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + TestDuplexConnection connection = requesterState.connection(); + + Disposable disposable = + Flux.interval(Duration.ofMillis(KEEP_ALIVE_INTERVAL)) + .subscribe( + n -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + requesterState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_TIMEOUT * 2)); + + RSocket rSocket = requesterState.rSocket(); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + + disposable.dispose(); + + requesterState.connection.dispose(); + requesterState.rSocket.dispose(); + + Assertions.assertThat(requesterState.connection.getSent()).allMatch(ByteBuf::release); + + requesterState.allocator.assertHasNoLeaks(); + } + + @Test + void noKeepAlivesSentAfterRSocketDispose() { + RSocketState requesterState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + requesterState.rSocket().dispose(); + + Duration duration = Duration.ofMillis(500); + + virtualTimeScheduler.advanceTimeBy(duration); + + FrameAssert.assertThat(requesterState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(requesterState.connection.pollFrame()).isNull(); + requesterState.allocator.assertHasNoLeaks(); + } + + @Test + void rSocketDisposedOnMissingKeepAlives() { + RSocketState requesterState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + + RSocket rSocket = requesterState.rSocket(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_TIMEOUT * 2)); + + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + rSocket + .onClose() + .as(StepVerifier::create) + .expectError(ConnectionErrorException.class) + .verify(Duration.ofMillis(100)); + + Assertions.assertThat(requesterState.connection.getSent()).allMatch(ByteBuf::release); + + requesterState.allocator.assertHasNoLeaks(); + } + + @Test + void clientRequesterSendsKeepAlives() { + RSocketState RSocketState = requester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + TestDuplexConnection connection = RSocketState.connection(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + this.keepAliveFrameWithRespondFlag(connection.pollFrame()); + + RSocketState.rSocket.dispose(); + FrameAssert.assertThat(connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasData("Disposed") + .hasNoLeaks(); + RSocketState.connection.dispose(); + + RSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void requesterRespondsToKeepAlives() { + RSocketState rSocketState = requester(100_000, 100_000); + TestDuplexConnection connection = rSocketState.connection(); + Duration duration = Duration.ofMillis(100); + Mono.delay(duration) + .subscribe( + l -> + connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode( + rSocketState.allocator, true, 0, Unpooled.EMPTY_BUFFER))); + + virtualTimeScheduler.advanceTimeBy(duration); + FrameAssert.assertThat(connection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .matches(this::keepAliveFrameWithoutRespondFlag); + + rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + rSocketState.connection.dispose(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRequesterNoKeepAlivesAfterDisconnect() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + TestDuplexConnection testConnection = rSocketState.connection(); + ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); + + resumableDuplexConnection.disconnect(); + + Duration duration = Duration.ofMillis(KEEP_ALIVE_INTERVAL * 5); + virtualTimeScheduler.advanceTimeBy(duration); + Assertions.assertThat(testConnection.pollFrame()).isNull(); + + rSocketState.rSocket.dispose(); + rSocketState.connection.dispose(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRequesterKeepAlivesAfterReconnect() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + ResumableDuplexConnection resumableDuplexConnection = rSocketState.resumableDuplexConnection(); + resumableDuplexConnection.disconnect(); + TestDuplexConnection newTestConnection = new TestDuplexConnection(rSocketState.alloc()); + resumableDuplexConnection.connect(newTestConnection); + // resumableDuplexConnection.(0, 0, ignored -> Mono.empty()); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(KEEP_ALIVE_INTERVAL)); + + FrameAssert.assertThat(newTestConnection.awaitFrame()) + .typeOf(FrameType.KEEPALIVE) + .hasStreamIdZero() + .hasNoLeaks(); + + rSocketState.rSocket.dispose(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + FrameAssert.assertThat(newTestConnection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") // API limitations + .hasNoLeaks(); + newTestConnection.dispose(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRequesterNoKeepAlivesAfterDispose() { + ResumableRSocketState rSocketState = + resumableRequester(KEEP_ALIVE_INTERVAL, KEEP_ALIVE_TIMEOUT); + rSocketState.rSocket().dispose(); + Duration duration = Duration.ofMillis(500); + StepVerifier.create(Flux.from(rSocketState.connection().getSentAsPublisher()).take(duration)) + .then(() -> virtualTimeScheduler.advanceTimeBy(duration)) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + rSocketState.rSocket.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Disposed") + .hasNoLeaks(); + rSocketState.connection.dispose(); + FrameAssert.assertThat(rSocketState.connection.pollFrame()) + .typeOf(FrameType.ERROR) + .hasStreamIdZero() + .hasData("Connection Closed Unexpectedly") + .hasNoLeaks(); + + rSocketState.allocator.assertHasNoLeaks(); + } + + @Test + void resumableRSocketsNotDisposedOnMissingKeepAlives() throws InterruptedException { + ResumableRSocketState resumableRequesterState = + resumableRequester(KEEP_ALIVE_INTERVAL, RESUMABLE_KEEP_ALIVE_TIMEOUT); + RSocket rSocket = resumableRequesterState.rSocket(); + TestDuplexConnection connection = resumableRequesterState.connection(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(500)); + + Assertions.assertThat(rSocket.isDisposed()).isFalse(); + Assertions.assertThat(connection.isDisposed()).isTrue(); + + Assertions.assertThat(resumableRequesterState.connection.getSent()).allMatch(ByteBuf::release); + + resumableRequesterState.connection.dispose(); + resumableRequesterState.rSocket.dispose(); + + resumableRequesterState.allocator.assertHasNoLeaks(); + } + + private boolean keepAliveFrame(ByteBuf frame) { + return FrameHeaderCodec.frameType(frame) == FrameType.KEEPALIVE; + } + + private boolean keepAliveFrameWithRespondFlag(ByteBuf frame) { + return keepAliveFrame(frame) && KeepAliveFrameCodec.respondFlag(frame) && frame.release(); + } + + private boolean keepAliveFrameWithoutRespondFlag(ByteBuf frame) { + return keepAliveFrame(frame) && !KeepAliveFrameCodec.respondFlag(frame) && frame.release(); + } + + static class RSocketState { + private final RSocket rSocket; + private final TestDuplexConnection connection; + private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; + + public RSocketState( + RSocket rSocket, + LeaksTrackingByteBufAllocator allocator, + TestDuplexConnection connection, + Sinks.Empty onClose) { + this.rSocket = rSocket; + this.connection = connection; + this.allocator = allocator; + this.onClose = onClose; + } + + public TestDuplexConnection connection() { + return connection; + } + + public RSocket rSocket() { + return rSocket; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + } + + static class ResumableRSocketState { + private final RSocket rSocket; + private final TestDuplexConnection connection; + private final ResumableDuplexConnection resumableDuplexConnection; + private final LeaksTrackingByteBufAllocator allocator; + private final Sinks.Empty onClose; + + public ResumableRSocketState( + RSocket rSocket, + TestDuplexConnection connection, + ResumableDuplexConnection resumableDuplexConnection, + Sinks.Empty onClose, + LeaksTrackingByteBufAllocator allocator) { + this.rSocket = rSocket; + this.connection = connection; + this.resumableDuplexConnection = resumableDuplexConnection; + this.onClose = onClose; + this.allocator = allocator; + } + + public TestDuplexConnection connection() { + return connection; + } + + public ResumableDuplexConnection resumableDuplexConnection() { + return resumableDuplexConnection; + } + + public RSocket rSocket() { + return rSocket; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java new file mode 100644 index 000000000..707d42afe --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/PayloadValidationUtilsTest.java @@ -0,0 +1,142 @@ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; + +import io.rsocket.Payload; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +class PayloadValidationUtilsTest { + + @Test + void shouldBeValidFrameWithNoFragmentation() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation1() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = + new byte[maxFrameLength - FRAME_LENGTH_SIZE - Integer.BYTES - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] data = new byte[maxFrameLength - FRAME_LENGTH_SIZE - FrameHeaderCodec.size() + 1]; + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation0() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[maxFrameLength / 2]; + byte[] data = + new byte + [(maxFrameLength / 2 + 1) + - FRAME_LENGTH_SIZE + - FrameHeaderCodec.size() + - FrameHeaderCodec.size()]; + ThreadLocalRandom.current().nextBytes(data); + ThreadLocalRandom.current().nextBytes(metadata); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + } + + @Test + void shouldBeInValidFrameWithNoFragmentation1() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isFalse(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isFalse(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation2() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(0, maxFrameLength, payload, false)) + .isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation3() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); + } + + @Test + void shouldBeValidFrameWithNoFragmentation4() { + int maxFrameLength = + ThreadLocalRandom.current().nextInt(64, FrameLengthCodec.FRAME_LENGTH_MASK); + byte[] metadata = new byte[1]; + byte[] data = new byte[1]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, true)) + .isTrue(); + Assertions.assertThat(PayloadValidationUtils.isValid(64, maxFrameLength, payload, false)) + .isTrue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java new file mode 100644 index 000000000..7cf12a81e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketConnectorTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; + +public class RSocketConnectorTest { + + @ParameterizedTest + @ValueSource(strings = {"KEEPALIVE", "REQUEST_RESPONSE"}) + public void unexpectedFramesBeforeResumeOKFrame(String frameType) { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create() + .resume(new Resume().retry(Retry.indefinitely())) + .connect(transport) + .block(); + + final TestDuplexConnection duplexConnection = transport.testConnection(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + FrameAssert.assertThat(duplexConnection.pollFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + + duplexConnection.dispose(); + + final TestDuplexConnection duplexConnection2 = transport.testConnection(); + + final ByteBuf frame; + switch (frameType) { + case "KEEPALIVE": + frame = + KeepAliveFrameCodec.encode(duplexConnection2.alloc(), false, 1, Unpooled.EMPTY_BUFFER); + break; + case "REQUEST_RESPONSE": + default: + frame = + RequestResponseFrameCodec.encode( + duplexConnection2.alloc(), 2, false, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + } + duplexConnection2.addToReceivedBuffer(frame); + + StepVerifier.create(duplexConnection2.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .typeOf(FrameType.RESUME) + .hasStreamIdZero() + .hasNoLeaks(); + + FrameAssert.assertThat(duplexConnection2.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("RESUME_OK frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresThatSetupPayloadCanBeRetained() { + AtomicReference retainedSetupPayload = new AtomicReference<>(); + TestClientTransport transport = new TestClientTransport(); + + ByteBuf data = transport.alloc().buffer(); + + data.writeCharSequence("data", CharsetUtil.UTF_8); + + RSocketConnector.create() + .setupPayload(ByteBufPayload.create(data)) + .acceptor( + (setup, sendingSocket) -> { + retainedSetupPayload.set(setup.retain()); + return Mono.just(new RSocket() {}); + }) + .connect(transport) + .block(); + + assertThat(transport.testConnection().getSent()) + .hasSize(1) + .first() + .matches( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return !payload.hasMetadata() && payload.getDataUtf8().equals("data"); + }) + .matches(buf -> buf.refCnt() == 2) + .matches( + buf -> { + buf.release(); + return buf.refCnt() == 1; + }); + + ConnectionSetupPayload setup = retainedSetupPayload.get(); + String dataUtf8 = setup.getDataUtf8(); + assertThat("data".equals(dataUtf8) && setup.release()).isTrue(); + assertThat(setup.refCnt()).isZero(); + + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresThatMonoFromRSocketConnectorCanBeUsedForMultipleSubscriptions() { + Payload setupPayload = ByteBufPayload.create("TestData", "TestMetadata"); + assertThat(setupPayload.refCnt()).isOne(); + + // Keep the data and metadata around so we can try changing them independently + ByteBuf dataBuf = setupPayload.data(); + ByteBuf metadataBuf = setupPayload.metadata(); + dataBuf.retain(); + metadataBuf.retain(); + + TestClientTransport testClientTransport = new TestClientTransport(); + Mono connectionMono = + RSocketConnector.create().setupPayload(setupPayload).connect(testClientTransport); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData") + && payload.getMetadataUtf8().equals("TestMetadata"); + }) + .allMatch(ReferenceCounted::release); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + // Changing the original data and metadata should not impact the SetupPayload + dataBuf.writerIndex(dataBuf.readerIndex()); + dataBuf.writeChar('d'); + dataBuf.release(); + + metadataBuf.writerIndex(metadataBuf.readerIndex()); + metadataBuf.writeChar('m'); + metadataBuf.release(); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData") + && payload.getMetadataUtf8().equals("TestMetadata"); + }) + .allMatch( + byteBuf -> { + System.out.println("calling release " + byteBuf.refCnt()); + return byteBuf.release(); + }); + assertThat(setupPayload.refCnt()).isZero(); + + testClientTransport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresThatSetupPayloadProvidedAsMonoIsReleased() { + List saved = new ArrayList<>(); + AtomicLong subscriptions = new AtomicLong(); + Mono setupPayloadMono = + Mono.create( + sink -> { + final long subscriptionN = subscriptions.getAndIncrement(); + Payload payload = + ByteBufPayload.create("TestData" + subscriptionN, "TestMetadata" + subscriptionN); + saved.add(payload); + sink.success(payload); + }); + + TestClientTransport testClientTransport = new TestClientTransport(); + Mono connectionMono = + RSocketConnector.create().setupPayload(setupPayloadMono).connect(testClientTransport); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData0") + && payload.getMetadataUtf8().equals("TestMetadata0"); + }) + .allMatch(ReferenceCounted::release); + + connectionMono + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofMillis(100)); + + assertThat(testClientTransport.testConnection().getSent()) + .hasSize(1) + .allMatch( + bb -> { + DefaultConnectionSetupPayload payload = new DefaultConnectionSetupPayload(bb); + return payload.getDataUtf8().equals("TestData1") + && payload.getMetadataUtf8().equals("TestMetadata1"); + }) + .allMatch(ReferenceCounted::release); + + assertThat(saved) + .as("Metadata and data were consumed and released as slices") + .allMatch( + payload -> + payload.refCnt() == 1 + && payload.data().refCnt() == 0 + && payload.metadata().refCnt() == 0); + + testClientTransport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeLessThenMtu() { + RSocketConnector.create() + .fragment(128) + .connect(new TestClientTransport().withMaxFrameLength(64)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maximumTransmissionUnit[128] exceeds configured maxFrameLength[64]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPayloadSize() { + RSocketConnector.create() + .maxInboundPayloadSize(128) + .connect(new TestClientTransport().withMaxFrameLength(256)) + .as(StepVerifier::create) + .expectErrorMessage("Configured maxFrameLength[256] exceeds maxPayloadSize[128]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPossibleFrameLength() { + RSocketConnector.create() + .connect(new TestClientTransport().withMaxFrameLength(Integer.MAX_VALUE)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maxFrameLength[" + + Integer.MAX_VALUE + + "] " + + "exceeds maxFrameLength limit " + + FRAME_LENGTH_MASK) + .verify(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java new file mode 100644 index 000000000..a461833d3 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketLeaseTest.java @@ -0,0 +1,724 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.LEASE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.SETUP; +import static org.assertj.core.data.Offset.offset; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.LeaseFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.lease.Lease; +import io.rsocket.lease.MissingLeaseException; +import io.rsocket.plugins.InitializingInterceptorRegistry; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestServerTransport; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.Collection; +import java.util.function.BiFunction; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +class RSocketLeaseTest { + private static final String TAG = "test"; + + private RSocket rSocketRequester; + private ResponderLeaseTracker responderLeaseTracker; + private LeaksTrackingByteBufAllocator byteBufAllocator; + private TestDuplexConnection connection; + private RSocketResponder rSocketResponder; + private RSocket mockRSocketHandler; + + private Sinks.Many leaseSender = Sinks.many().multicast().onBackpressureBuffer(); + private RequesterLeaseTracker requesterLeaseTracker; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @BeforeEach + void setUp() { + PayloadDecoder payloadDecoder = PayloadDecoder.DEFAULT; + byteBufAllocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + connection = new TestDuplexConnection(byteBufAllocator); + requesterLeaseTracker = new RequesterLeaseTracker(TAG, 0); + responderLeaseTracker = new ResponderLeaseTracker(TAG, connection, () -> leaseSender.asFlux()); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + + ClientServerInputMultiplexer multiplexer = + new ClientServerInputMultiplexer(connection, new InitializingInterceptorRegistry(), true); + rSocketRequester = + new RSocketRequester( + multiplexer.asClientConnection(), + payloadDecoder, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + requesterLeaseTracker, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + + mockRSocketHandler = mock(RSocket.class); + when(mockRSocketHandler.metadataPush(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Mono.empty(); + }); + when(mockRSocketHandler.fireAndForget(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Mono.empty(); + }); + when(mockRSocketHandler.requestResponse(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Mono.empty(); + }); + when(mockRSocketHandler.requestStream(any())) + .then( + a -> { + Payload payload = a.getArgument(0); + payload.release(); + return Flux.empty(); + }); + when(mockRSocketHandler.requestChannel(any())) + .then( + a -> { + Publisher payloadPublisher = a.getArgument(0); + return Flux.from(payloadPublisher) + .doOnNext(ReferenceCounted::release) + .transform( + Operators.lift( + (__, actual) -> + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + actual.onSubscribe(this); + } + + @Override + protected void hookOnComplete() { + actual.onComplete(); + } + + @Override + protected void hookOnError(Throwable throwable) { + actual.onError(throwable); + } + })); + }); + + rSocketResponder = + new RSocketResponder( + multiplexer.asServerConnection(), + mockRSocketHandler, + payloadDecoder, + responderLeaseTracker, + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); + } + + @AfterEach + void tearDownAndCheckForLeaks() { + byteBufAllocator.assertHasNoLeaks(); + } + + @Test + public void serverRSocketFactoryRejectsUnsupportedLease() { + Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); + ByteBuf setupFrame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + true, + 1000, + 30_000, + "application/octet-stream", + "application/octet-stream", + payload); + + TestServerTransport transport = new TestServerTransport(); + RSocketServer.create().bind(transport).block(); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer(setupFrame); + + Collection sent = connection.getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf error = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(error)).isEqualTo(ERROR); + Assertions.assertThat(Exceptions.from(0, error).getMessage()) + .isEqualTo("lease is not supported"); + error.release(); + connection.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void clientRSocketFactorySetsLeaseFlag() { + TestClientTransport clientTransport = new TestClientTransport(); + try { + RSocketConnector.create().lease().connect(clientTransport).block(); + Collection sent = clientTransport.testConnection().getSent(); + Assertions.assertThat(sent).hasSize(1); + ByteBuf setup = sent.iterator().next(); + Assertions.assertThat(FrameHeaderCodec.frameType(setup)).isEqualTo(SETUP); + Assertions.assertThat(SetupFrameCodec.honorLease(setup)).isTrue(); + setup.release(); + } finally { + clientTransport.testConnection().dispose(); + clientTransport.alloc().assertHasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterMissingLeaseRequestsAreRejected( + BiFunction> interaction) { + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + StepVerifier.create(interaction.apply(rSocketRequester, payload1)) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterPresentLeaseRequestsAreAccepted( + BiFunction> interaction, FrameType frameType) { + ByteBuf frame = leaseFrame(5_000, 2, Unpooled.EMPTY_BUFFER); + requesterLeaseTracker.handleLeaseFrame(frame); + + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(1.0, offset(1e-2)); + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + Flux.from(interaction.apply(rSocketRequester, payload1)) + .as(StepVerifier::create) + .then( + () -> { + if (frameType != REQUEST_FNF) { + connection.addToReceivedBuffer( + PayloadFrameCodec.encodeComplete(byteBufAllocator, 1)); + } + }) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + if (frameType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == frameType) + .matches(ReferenceCounted::release); + } + + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.5, offset(1e-2)); + + Assertions.assertThat(frame.release()).isTrue(); + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + @SuppressWarnings({"rawtypes", "unchecked"}) + void requesterDepletedAllowedLeaseRequestsAreRejected( + BiFunction> interaction, FrameType interactionType) { + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + ByteBuf leaseFrame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); + requesterLeaseTracker.handleLeaseFrame(leaseFrame); + + double initialAvailability = requesterLeaseTracker.availability(); + Publisher request = interaction.apply(rSocketRequester, payload1); + + // ensures that lease is not used until the frame is sent + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); + Assertions.assertThat(connection.getSent()).hasSize(0); + + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + request.subscribe(assertSubscriber); + + // if request is FNF, then request frame is sent on subscribe + // otherwise we need to make request(1) + if (interactionType != REQUEST_FNF) { + Assertions.assertThat(initialAvailability).isEqualTo(requesterLeaseTracker.availability()); + Assertions.assertThat(connection.getSent()).hasSize(0); + + assertSubscriber.request(1); + } + + // ensures availability is changed and lease is used only up on frame sending + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(0.0, offset(1e-2)); + + if (interactionType == REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == interactionType) + .matches(ReferenceCounted::release); + } + + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + Flux.from(interaction.apply(rSocketRequester, payload2)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(leaseFrame.release()).isTrue(); + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + void requesterExpiredLeaseRequestsAreRejected( + BiFunction> interaction) { + ByteBuf frame = leaseFrame(50, 1, Unpooled.EMPTY_BUFFER); + requesterLeaseTracker.handleLeaseFrame(frame); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + Flux.defer(() -> interaction.apply(rSocketRequester, payload1)) + .delaySubscription(Duration.ofMillis(200)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(frame.release()).isTrue(); + + byteBufAllocator.assertHasNoLeaks(); + } + + @Test + void requesterAvailabilityRespectsTransport() { + ByteBuf frame = leaseFrame(5_000, 1, Unpooled.EMPTY_BUFFER); + try { + + requesterLeaseTracker.handleLeaseFrame(frame); + double unavailable = 0.0; + connection.setAvailability(unavailable); + Assertions.assertThat(rSocketRequester.availability()).isCloseTo(unavailable, offset(1e-2)); + } finally { + frame.release(); + } + } + + @ParameterizedTest + @MethodSource("responderInteractions") + void responderMissingLeaseRequestsAreRejected(FrameType frameType) { + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("responderInteractions") + void responderPresentLeaseRequestsAreAccepted(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2)); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFireAndForget(1, fnfFrame); + fnfFrame.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + rSocketResponder.handleFrame(requestResponseFrame); + requestResponseFrame.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + rSocketResponder.handleFrame(requestStreamFrame); + requestStreamFrame.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + rSocketResponder.handleFrame(requestChannelFrame); + requestChannelFrame.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } + + Assertions.assertThat(connection.getSent()) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) + .matches(ReferenceCounted::release); + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("responderInteractions") + void responderDepletedAllowedLeaseRequestsAreRejected(FrameType frameType) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 1)); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + ByteBuf buffer2 = byteBufAllocator.buffer(); + buffer2.writeCharSequence("test2", CharsetUtil.UTF_8); + Payload payload2 = ByteBufPayload.create(buffer2); + + switch (frameType) { + case REQUEST_FNF: + final ByteBuf fnfFrame = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf fnfFrame2 = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(fnfFrame); + rSocketResponder.handleFrame(fnfFrame2); + fnfFrame.release(); + fnfFrame2.release(); + break; + case REQUEST_RESPONSE: + final ByteBuf requestResponseFrame = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, payload1); + final ByteBuf requestResponseFrame2 = + RequestResponseFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, payload2); + rSocketResponder.handleFrame(requestResponseFrame); + rSocketResponder.handleFrame(requestResponseFrame2); + requestResponseFrame.release(); + requestResponseFrame2.release(); + break; + case REQUEST_STREAM: + final ByteBuf requestStreamFrame = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, 1, payload1); + final ByteBuf requestStreamFrame2 = + RequestStreamFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, 1, payload2); + rSocketResponder.handleFrame(requestStreamFrame); + rSocketResponder.handleFrame(requestStreamFrame2); + requestStreamFrame.release(); + requestStreamFrame2.release(); + break; + case REQUEST_CHANNEL: + final ByteBuf requestChannelFrame = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 1, true, 1, payload1); + final ByteBuf requestChannelFrame2 = + RequestChannelFrameCodec.encodeReleasingPayload(byteBufAllocator, 3, true, 1, payload2); + rSocketResponder.handleFrame(requestChannelFrame); + rSocketResponder.handleFrame(requestChannelFrame2); + requestChannelFrame.release(); + requestChannelFrame2.release(); + break; + } + + switch (frameType) { + case REQUEST_FNF: + Mockito.verify(mockRSocketHandler).fireAndForget(any()); + break; + case REQUEST_RESPONSE: + Mockito.verify(mockRSocketHandler).requestResponse(any()); + break; + case REQUEST_STREAM: + Mockito.verify(mockRSocketHandler).requestStream(any()); + break; + case REQUEST_CHANNEL: + Mockito.verify(mockRSocketHandler).requestChannel(any()); + break; + } + + Assertions.assertThat(connection.getSent()) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) + .matches(ReferenceCounted::release); + + if (frameType != REQUEST_FNF) { + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == COMPLETE) + .matches(ReferenceCounted::release); + + Assertions.assertThat(connection.getSent()) + .hasSize(3) + .element(2) + .matches(bb -> FrameHeaderCodec.frameType(bb) == ERROR) + .matches(bb -> Exceptions.from(1, bb) instanceof RejectedException) + .matches(ReferenceCounted::release); + } + + byteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("interactions") + void expiredLeaseRequestsAreRejected(BiFunction> interaction) { + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(50), 1)); + + ByteBuf buffer = byteBufAllocator.buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + Payload payload1 = ByteBufPayload.create(buffer); + + Flux.from(interaction.apply(rSocketRequester, payload1)) + .delaySubscription(Duration.ofMillis(100)) + .as(StepVerifier::create) + .expectError(MissingLeaseException.class) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == LEASE) + .matches(ReferenceCounted::release); + + byteBufAllocator.assertHasNoLeaks(); + } + + @Test + void sendLease() { + ByteBuf metadata = byteBufAllocator.buffer(); + Charset utf8 = StandardCharsets.UTF_8; + String metadataContent = "test"; + metadata.writeCharSequence(metadataContent, utf8); + int ttl = 5_000; + int numberOfRequests = 2; + leaseSender.tryEmitNext(Lease.create(Duration.ofMillis(5_000), 2, metadata)); + + ByteBuf leaseFrame = + connection + .getSent() + .stream() + .filter(f -> FrameHeaderCodec.frameType(f) == FrameType.LEASE) + .findFirst() + .orElseThrow(() -> new IllegalStateException("Lease frame not sent")); + + try { + Assertions.assertThat(LeaseFrameCodec.ttl(leaseFrame)).isEqualTo(ttl); + Assertions.assertThat(LeaseFrameCodec.numRequests(leaseFrame)).isEqualTo(numberOfRequests); + Assertions.assertThat(LeaseFrameCodec.metadata(leaseFrame).toString(utf8)) + .isEqualTo(metadataContent); + } finally { + leaseFrame.release(); + } + } + + // @Test + // void receiveLease() { + // Collection receivedLeases = new ArrayList<>(); + // leaseReceiver.subscribe(lease -> receivedLeases.add(lease)); + // + // ByteBuf metadata = byteBufAllocator.buffer(); + // Charset utf8 = StandardCharsets.UTF_8; + // String metadataContent = "test"; + // metadata.writeCharSequence(metadataContent, utf8); + // int ttl = 5_000; + // int numberOfRequests = 2; + // + // ByteBuf leaseFrame = leaseFrame(ttl, numberOfRequests, metadata).retain(1); + // + // connection.addToReceivedBuffer(leaseFrame); + // + // Assertions.assertThat(receivedLeases.isEmpty()).isFalse(); + // Lease receivedLease = receivedLeases.iterator().next(); + // Assertions.assertThat(receivedLease.getTimeToLiveMillis()).isEqualTo(ttl); + // + // Assertions.assertThat(receivedLease.getStartingAllowedRequests()).isEqualTo(numberOfRequests); + // Assertions.assertThat(receivedLease.metadata().toString(utf8)).isEqualTo(metadataContent); + // + // ReferenceCountUtil.safeRelease(leaseFrame); + // } + + ByteBuf leaseFrame(int ttl, int requests, ByteBuf metadata) { + return LeaseFrameCodec.encode(byteBufAllocator, ttl, requests, metadata); + } + + static Stream interactions() { + return Stream.of( + Arguments.of( + (BiFunction>) RSocket::fireAndForget, + FrameType.REQUEST_FNF), + Arguments.of( + (BiFunction>) RSocket::requestResponse, + FrameType.REQUEST_RESPONSE), + Arguments.of( + (BiFunction>) RSocket::requestStream, + FrameType.REQUEST_STREAM), + Arguments.of( + (BiFunction>) + (rSocket, payload) -> rSocket.requestChannel(Mono.just(payload)), + FrameType.REQUEST_CHANNEL)); + } + + static Stream responderInteractions() { + return Stream.of( + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java new file mode 100644 index 000000000..966fd65f2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketReconnectTest.java @@ -0,0 +1,203 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.transport.ClientTransport; +import java.io.UncheckedIOException; +import java.time.Duration; +import java.util.Iterator; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.Exceptions; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketReconnectTest { + + private Queue retries = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldBeASharedReconnectableInstanceOfRSocketMono() throws InterruptedException { + TestClientTransport[] testClientTransport = + new TestClientTransport[] {new TestClientTransport()}; + Mono rSocketMono = + RSocketConnector.create() + .reconnect(Retry.indefinitely()) + .connect(() -> testClientTransport[0]); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket1).isEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); + testClientTransport[0] = new TestClientTransport(); + + RSocket rSocket3 = rSocketMono.block(); + RSocket rSocket4 = rSocketMono.block(); + + FrameAssert.assertThat(testClientTransport[0].testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + assertThat(rSocket3).isEqualTo(rSocket4).isNotEqualTo(rSocket2); + + testClientTransport[0].testConnection().dispose(); + rSocket3.onClose().block(Duration.ofSeconds(1)); + testClientTransport[0].alloc().assertHasNoLeaks(); + } + + @Test + @SuppressWarnings({"rawtype"}) + public void shouldBeRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); + Mockito.when(transport.connect()) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(transport1.connect()); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + RSocket rSocket1 = rSocketMono.block(); + RSocket rSocket2 = rSocketMono.block(); + + assertThat(rSocket1).isEqualTo(rSocket2); + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + + FrameAssert.assertThat(transport1.testConnection().awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + transport1.testConnection().dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + transport1.alloc().assertHasNoLeaks(); + } + + @Test + @SuppressWarnings({"rawtype"}) + public void shouldBeExaustedRetrieableConnectionSharedReconnectableInstanceOfRSocketMono() { + ClientTransport transport = Mockito.mock(ClientTransport.class); + TestClientTransport transport1 = new TestClientTransport(); + Mockito.when(transport.connect()) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenThrow(UncheckedIOException.class) + .thenReturn(transport1.connect()); + Mono rSocketMono = + RSocketConnector.create() + .reconnect( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .doAfterRetry(onRetry())) + .connect(transport); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + Assertions.assertThatThrownBy(rSocketMono::block) + .matches(Exceptions::isRetryExhausted) + .hasCauseInstanceOf(UncheckedIOException.class); + + assertRetries( + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class, + UncheckedIOException.class); + + transport1.alloc().assertHasNoLeaks(); + } + + @Test + public void shouldBeNotBeASharedReconnectableInstanceOfRSocketMono() { + TestClientTransport transport = new TestClientTransport(); + Mono rSocketMono = RSocketConnector.connectWith(transport); + + RSocket rSocket1 = rSocketMono.block(); + TestDuplexConnection connection1 = transport.testConnection(); + + FrameAssert.assertThat(connection1.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + RSocket rSocket2 = rSocketMono.block(); + TestDuplexConnection connection2 = transport.testConnection(); + + assertThat(rSocket1).isNotEqualTo(rSocket2); + + FrameAssert.assertThat(connection2.awaitFrame()) + .typeOf(FrameType.SETUP) + .hasStreamIdZero() + .hasNoLeaks(); + + connection1.dispose(); + connection2.dispose(); + rSocket1.onClose().block(Duration.ofSeconds(1)); + rSocket2.onClose().block(Duration.ofSeconds(1)); + transport.alloc().assertHasNoLeaks(); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertThat(retries.size()).isEqualTo(exceptions.length); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); + index++; + } + } + + Consumer onRetry() { + return context -> retries.add(context); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java new file mode 100644 index 000000000..01eb998c7 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterSubscribersTest.java @@ -0,0 +1,206 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.util.RaceTestUtils; + +class RSocketRequesterSubscribersTest { + + private static final Set REQUEST_TYPES = + new HashSet<>( + Arrays.asList( + FrameType.METADATA_PUSH, + FrameType.REQUEST_FNF, + FrameType.REQUEST_RESPONSE, + FrameType.REQUEST_STREAM, + FrameType.REQUEST_CHANNEL)); + + private LeaksTrackingByteBufAllocator allocator; + private RSocket rSocketRequester; + private TestDuplexConnection connection; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @AfterEach + void tearDownAndCheckNoLeaks() { + allocator.assertHasNoLeaks(); + } + + @BeforeEach + void setUp() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + connection = new TestDuplexConnection(allocator); + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + rSocketRequester = + new RSocketRequester( + connection, + PayloadDecoder.DEFAULT, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + } + + @ParameterizedTest + @MethodSource("allInteractions") + @SuppressWarnings({"rawtypes", "unchecked"}) + void singleSubscriber(Function> interaction, FrameType requestType) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + response.subscribe(assertSubscriberA); + response.subscribe(assertSubscriberB); + + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), 1)); + } + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); + + FrameAssert.assertThat(connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + + if (requestType == FrameType.REQUEST_CHANNEL) { + FrameAssert.assertThat(connection.pollFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } + } + + @ParameterizedTest + @MethodSource("allInteractions") + void singleSubscriberInCaseOfRacing( + Function> interaction, FrameType requestType) { + for (int i = 1; i < 20000; i += 2) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + AssertSubscriber assertSubscriberA = AssertSubscriber.create(); + AssertSubscriber assertSubscriberB = AssertSubscriber.create(); + + RaceTestUtils.race( + () -> response.subscribe(assertSubscriberA), () -> response.subscribe(assertSubscriberB)); + + if (requestType != FrameType.REQUEST_FNF && requestType != FrameType.METADATA_PUSH) { + connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(connection.alloc(), i)); + } + + assertSubscriberA.assertTerminated(); + assertSubscriberB.assertTerminated(); + + Assertions.assertThat(new AssertSubscriber[] {assertSubscriberA, assertSubscriberB}) + .anySatisfy(as -> as.assertError(IllegalStateException.class)); + + if (requestType == FrameType.REQUEST_CHANNEL) { + Assertions.assertThat(connection.getSent()) + .hasSize(2) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + Assertions.assertThat(connection.getSent()) + .element(1) + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.COMPLETE) + .matches(ByteBuf::release); + } else { + Assertions.assertThat(connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(bb))) + .matches(ByteBuf::release); + } + connection.clearSendReceiveBuffers(); + } + } + + @ParameterizedTest + @MethodSource("allInteractions") + void singleSubscriberInteractionsAreLazy(Function> interaction) { + Flux response = Flux.from(interaction.apply(rSocketRequester)); + + Assertions.assertThat(connection.getSent().size()).isEqualTo(0); + } + + static long requestFramesCount(Collection frames) { + return frames + .stream() + .filter(frame -> REQUEST_TYPES.contains(FrameHeaderCodec.frameType(frame))) + .count(); + } + + static Stream allInteractions() { + return Stream.of( + Arguments.of( + (Function>) + rSocket -> rSocket.fireAndForget(DefaultPayload.create("test")), + FrameType.REQUEST_FNF), + Arguments.of( + (Function>) + rSocket -> rSocket.requestResponse(DefaultPayload.create("test")), + FrameType.REQUEST_RESPONSE), + Arguments.of( + (Function>) + rSocket -> rSocket.requestStream(DefaultPayload.create("test")), + FrameType.REQUEST_STREAM), + Arguments.of( + (Function>) + rSocket -> rSocket.requestChannel(Mono.just(DefaultPayload.create("test"))), + FrameType.REQUEST_CHANNEL), + Arguments.of( + (Function>) + rSocket -> + rSocket.metadataPush( + DefaultPayload.create(new byte[0], "test".getBytes(CharsetUtil.UTF_8))), + FrameType.METADATA_PUSH)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java new file mode 100644 index 000000000..5cfa76a1c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTerminationTest.java @@ -0,0 +1,113 @@ +package io.rsocket.core; + +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketRequesterTest.ClientSocketRule; +import io.rsocket.frame.FrameType; +import io.rsocket.util.EmptyPayload; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.Arrays; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public class RSocketRequesterTerminationTest { + + public final ClientSocketRule rule = new ClientSocketRule(); + + @BeforeEach + public void setup() { + rule.init(); + } + + @AfterEach + public void tearDownAndCheckNoLeaks() { + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testCurrentStreamIsTerminatedOnConnectionClose( + FrameType requestType, Function> interaction) { + RSocketRequester rSocket = rule.socket; + + StepVerifier.create(interaction.apply(rSocket)) + .then( + () -> { + FrameAssert.assertThat(rule.connection.pollFrame()).typeOf(requestType).hasNoLeaks(); + }) + .then(() -> rule.connection.dispose()) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("rsocketInteractions") + public void testSubsequentStreamIsTerminatedAfterConnectionClose( + FrameType requestType, Function> interaction) { + RSocketRequester rSocket = rule.socket; + + rule.connection.dispose(); + StepVerifier.create(interaction.apply(rSocket)) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(5)); + } + + public static Iterable rsocketInteractions() { + EmptyPayload payload = EmptyPayload.INSTANCE; + + Arguments resp = + Arguments.of( + FrameType.REQUEST_RESPONSE, + new Function>() { + @Override + public Mono apply(RSocket rSocket) { + return rSocket.requestResponse(payload); + } + + @Override + public String toString() { + return "Request Response"; + } + }); + Arguments stream = + Arguments.of( + FrameType.REQUEST_STREAM, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestStream(payload); + } + + @Override + public String toString() { + return "Request Stream"; + } + }); + Arguments channel = + Arguments.of( + FrameType.REQUEST_CHANNEL, + new Function>() { + @Override + public Flux apply(RSocket rSocket) { + return rSocket.requestChannel(Flux.never().startWith(payload)); + } + + @Override + public String toString() { + return "Request Channel"; + } + }); + + return Arrays.asList(resp, stream, channel); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java new file mode 100644 index 000000000..a1199f698 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketRequesterTest.java @@ -0,0 +1,1516 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.ReassemblyUtils.ILLEGAL_REASSEMBLED_PAYLOAD_SIZE; +import static io.rsocket.core.TestRequesterResponderSupport.fixedSizePayload; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.core.TestRequesterResponderSupport.prepareFragments; +import static io.rsocket.core.TestRequesterResponderSupport.randomMetadataOnlyPayload; +import static io.rsocket.core.TestRequesterResponderSupport.randomPayload; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.CANCEL; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.Scannable; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketRequesterTest { + + ClientSocketRule rule; + + @BeforeEach + public void setUp() throws Throwable { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped((t) -> {}); + rule = new ClientSocketRule(); + rule.init(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testInvalidFrameOnStream0ShouldNotTerminateRSocket() { + rule.connection.addToReceivedBuffer(RequestNFrameCodec.encode(rule.alloc(), 0, 10)); + assertThat(rule.socket.isDisposed()).isFalse(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testStreamInitialN() { + Flux stream = rule.socket.requestStream(EmptyPayload.INSTANCE); + + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + // don't request here + } + }; + stream.subscribe(subscriber); + + assertThat(rule.connection.getSent()).isEmpty(); + + subscriber.request(5); + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat(sent.size()).describedAs("sent frame count").isEqualTo(1); + + ByteBuf f = sent.get(0); + + assertThat(frameType(f)).describedAs("initial frame").isEqualTo(REQUEST_STREAM); + assertThat(RequestStreamFrameCodec.initialRequestN(f)) + .describedAs("initial request n") + .isEqualTo(5L); + assertThat(f.release()).describedAs("should be released").isEqualTo(true); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleSetupException() { + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 0, new RejectedSetupException("boom"))); + assertThatThrownBy(() -> rule.socket.onClose().block()) + .isInstanceOf(RejectedSetupException.class); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleApplicationException() { + rule.connection.clearSendReceiveBuffers(); + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(); + response.subscribe(responseSub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), streamId, new ApplicationErrorException("error"))); + + verify(responseSub).onError(any(ApplicationErrorException.class)); + + assertThat(rule.connection.getSent()) + // requestResponseFrame + .hasSize(1) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testHandleValidFrame() { + Publisher response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber sub = TestSubscriber.create(); + response.subscribe(sub); + + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encodeNextReleasingPayload( + rule.alloc(), streamId, EmptyPayload.INSTANCE)); + + verify(sub).onComplete(); + assertThat(rule.connection.getSent()).hasSize(1).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testRequestReplyWithCancel() { + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + + try { + response.block(Duration.ofMillis(100)); + } catch (IllegalStateException ise) { + } + + List sent = new ArrayList<>(rule.connection.getSent()); + + assertThat(frameType(sent.get(0))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(REQUEST_RESPONSE); + assertThat(frameType(sent.get(1))) + .describedAs("Unexpected frame sent on the connection.") + .isEqualTo(CANCEL); + assertThat(sent).hasSize(2).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + @Disabled("invalid") + @Timeout(2_000) + public void testRequestReplyErrorOnSend() { + rule.connection.setAvailability(0); // Fails send + Mono response = rule.socket.requestResponse(EmptyPayload.INSTANCE); + Subscriber responseSub = TestSubscriber.create(10); + response.subscribe(responseSub); + + this.rule + .socket + .onClose() + .as(StepVerifier::create) + .expectComplete() + .verify(Duration.ofMillis(100)); + + verify(responseSub).onSubscribe(any(Subscription.class)); + + rule.assertHasNoLeaks(); + // TODO this should get the error reported through the response subscription + // verify(responseSub).onError(any(RuntimeException.class)); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation() { + Sinks.Empty cancelled = Sinks.empty(); + Flux request = Flux.never().doOnCancel(cancelled::tryEmitEmpty); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.firstWithSignal( + cancelled.asMono(), + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + public void testChannelRequestCancellation2() { + Sinks.Empty cancelled = Sinks.empty(); + Flux request = + Flux.just(EmptyPayload.INSTANCE).repeat(259).doOnCancel(cancelled::tryEmitEmpty); + rule.socket.requestChannel(request).subscribe().dispose(); + Flux.firstWithSignal( + cancelled.asMono(), + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testChannelRequestServerSideCancellation() { + Sinks.One cancelled = Sinks.one(); + Sinks.Many request = Sinks.many().unicast().onBackpressureBuffer(); + request.tryEmitNext(EmptyPayload.INSTANCE); + rule.socket + .requestChannel(request.asFlux()) + .subscribe(cancelled::tryEmitValue, cancelled::tryEmitError, cancelled::tryEmitEmpty); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(rule.alloc(), streamId)); + rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.alloc(), streamId)); + Flux.firstWithSignal( + cancelled.asMono(), + Flux.error(new IllegalStateException("Channel request not cancelled")) + .delaySubscription(Duration.ofSeconds(1))) + .blockFirst(); + + assertThat(request.scan(Scannable.Attr.TERMINATED) || request.scan(Scannable.Attr.CANCELLED)) + .isTrue(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_CHANNEL) + .matches(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @Test + public void testCorrectFrameOrder() { + Sinks.One delayer = Sinks.one(); + BaseSubscriber subscriber = + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) {} + }; + rule.socket + .requestChannel( + Flux.concat(Flux.just(0).delayUntil(i -> delayer.asMono()), Flux.range(1, 999)) + .map(i -> DefaultPayload.create(i + ""))) + .subscribe(subscriber); + + subscriber.request(1); + subscriber.request(Long.MAX_VALUE); + delayer.tryEmitEmpty(); + + Iterator iterator = rule.connection.getSent().iterator(); + + ByteBuf initialFrame = iterator.next(); + + assertThat(FrameHeaderCodec.frameType(initialFrame)).isEqualTo(REQUEST_CHANNEL); + assertThat(RequestChannelFrameCodec.initialRequestN(initialFrame)).isEqualTo(Long.MAX_VALUE); + assertThat(RequestChannelFrameCodec.data(initialFrame).toString(CharsetUtil.UTF_8)) + .isEqualTo("0"); + assertThat(initialFrame.release()).isTrue(); + + assertThat(iterator.hasNext()).isFalse(); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + generator.apply(rule.socket, DefaultPayload.create(data, metadata))) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) + .verify(); + rule.assertHasNoLeaks(); + }); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation1( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + prepareCalls() + .forEach( + generator -> { + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + assertThatThrownBy( + () -> { + final Publisher source = + generator.apply(rule.socket, DefaultPayload.create(data, metadata)); + + if (source instanceof Mono) { + ((Mono) source).block(); + } else { + ((Flux) source).blockLast(); + } + }) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength)); + + rule.assertHasNoLeaks(); + }); + } + + @Test + public void shouldRejectCallOfNoMetadataPayload() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + StepVerifier.create(rule.socket.metadataPush(payload)) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present")) + .verify(); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + + @Test + public void shouldRejectCallOfNoMetadataPayloadBlocking() { + final ByteBuf data = rule.allocator.buffer(10); + final Payload payload = ByteBufPayload.create(data); + + assertThatThrownBy(() -> rule.socket.metadataPush(payload).block()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Metadata push should have metadata field present"); + PayloadAssert.assertThat(payload).isReleased(); + rule.assertHasNoLeaks(); + } + + static Stream>> prepareCalls() { + return Stream.of( + RSocket::fireAndForget, + RSocket::requestResponse, + RSocket::requestStream, + (rSocket, payload) -> rSocket.requestChannel(Flux.just(payload)), + RSocket::metadataPush); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FrameLengthCodec.FRAME_LENGTH_MASK}) + public void + shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentationForRequestChannelCase( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + StepVerifier.create( + rule.socket.requestChannel( + Flux.just(EmptyPayload.INSTANCE, DefaultPayload.create(data, metadata))), + 0) + .expectSubscription() + .thenRequest(2) + .then( + () -> { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + rule.alloc(), rule.getStreamIdForRequestType(REQUEST_CHANNEL), 2)); + }) + .expectErrorSatisfies( + t -> + assertThat(t) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) + .verify(); + assertThat(rule.connection.getSent()) + // expect to be sent RequestChannelFrame + // expect to be sent CancelFrame + .hasSize(2) + .allMatch(ReferenceCounted::release); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("racingCases") + public void checkNoLeaksOnRacing( + Function> initiator, + BiConsumer, ClientSocketRule> runner) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ClientSocketRule clientSocketRule = new ClientSocketRule(); + + clientSocketRule.init(); + + Publisher payloadP = initiator.apply(clientSocketRule); + AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + if (payloadP instanceof Flux) { + ((Flux) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } else { + ((Mono) payloadP).doOnNext(Payload::release).subscribe(assertSubscriber); + } + + runner.accept(assertSubscriber, clientSocketRule); + + assertThat(clientSocketRule.connection.getSent()).allMatch(ReferenceCounted::release); + + clientSocketRule.assertHasNoLeaks(); + } + } + + private static Stream racingCases() { + return Stream.of( + Arguments.of( + (Function>) + (rule) -> rule.socket.requestStream(EmptyPayload.INSTANCE), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> rule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE)), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + + return rule.socket.requestStream(payload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + if (rule.connection.getSent().size() > 0) { + assertThat(rule.connection.getSent()).hasSize(2); + assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_STREAM, + "Expected first frame matches {" + + REQUEST_STREAM + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected first frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBufAllocator allocator = rule.alloc(); + return rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("metadata", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("data", CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + sink.complete(); + return ++index; + })); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + RaceTestUtils.race(() -> as.request(1), as::cancel); + // ensures proper frames order + int size = rule.connection.getSent().size(); + if (size > 0) { + + assertThat(size).isLessThanOrEqualTo(3).isGreaterThanOrEqualTo(2); + assertThat(rule.connection.getSent()) + .element(0) + .matches( + bb -> frameType(bb) == REQUEST_CHANNEL, + "Expected first frame matches {" + + REQUEST_CHANNEL + + "} but was {" + + frameType(rule.connection.getSent().stream().findFirst().get()) + + "}"); + if (size == 2) { + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == CANCEL, + "Expected second frame matches {" + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + } else { + assertThat(rule.connection.getSent()) + .element(1) + .matches( + bb -> frameType(bb) == COMPLETE || frameType(bb) == CANCEL, + "Expected second frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(1).findFirst().get()) + + "}"); + assertThat(rule.connection.getSent()) + .element(2) + .matches( + bb -> frameType(bb) == CANCEL || frameType(bb) == COMPLETE, + "Expected third frame matches {" + + COMPLETE + + " or " + + CANCEL + + "} but was {" + + frameType( + rule.connection.getSent().stream().skip(2).findFirst().get()) + + "}"); + } + } + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = CancelFrameCodec.encode(allocator, streamId); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> + rule.socket.requestChannel( + Flux.generate( + () -> 1L, + (index, sink) -> { + ByteBuf data = rule.alloc().buffer(); + data.writeCharSequence("d" + index, CharsetUtil.UTF_8); + ByteBuf metadata = rule.alloc().buffer(); + metadata.writeCharSequence("m" + index, CharsetUtil.UTF_8); + final Payload payload = ByteBufPayload.create(data, metadata); + sink.next(payload); + return ++index; + })), + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + as.request(1); + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + ByteBuf frame = + ErrorFrameCodec.encode(allocator, streamId, new RuntimeException("test")); + + RaceTestUtils.race( + () -> as.request(Long.MAX_VALUE), + () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBuf data = rule.allocator.buffer(); + data.writeCharSequence("testData", CharsetUtil.UTF_8); + + ByteBuf metadata = rule.allocator.buffer(); + metadata.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + Payload requestPayload = ByteBufPayload.create(data, metadata); + return rule.socket.requestResponse(requestPayload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_RESPONSE); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, false, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + }), + Arguments.of( + (Function>) + (rule) -> { + ByteBuf data = rule.allocator.buffer(); + data.writeCharSequence("testData", CharsetUtil.UTF_8); + + ByteBuf metadata = rule.allocator.buffer(); + metadata.writeCharSequence("testMetadata", CharsetUtil.UTF_8); + Payload requestPayload = ByteBufPayload.create(data, metadata); + return rule.socket.requestStream(requestPayload); + }, + (BiConsumer, ClientSocketRule>) + (as, rule) -> { + ByteBufAllocator allocator = rule.alloc(); + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence("abc", CharsetUtil.UTF_8); + ByteBuf data = allocator.buffer(); + data.writeCharSequence("def", CharsetUtil.UTF_8); + as.request(Long.MAX_VALUE); + int streamId = rule.getStreamIdForRequestType(REQUEST_STREAM); + ByteBuf frame = + PayloadFrameCodec.encode( + allocator, streamId, false, true, true, metadata, data); + + RaceTestUtils.race(as::cancel, () -> rule.connection.addToReceivedBuffer(frame)); + })); + } + + @Test + public void simpleOnDiscardRequestChannelTest() { + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); + + assertSubscriber.cancel(); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleOnDiscardRequestChannelTest2() { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + Sinks.Many testPublisher = Sinks.many().unicast().onBackpressureBuffer(); + + Flux payloadFlux = rule.socket.requestChannel(testPublisher.asFlux()); + + payloadFlux.subscribe(assertSubscriber); + + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d"), ByteBufUtil.writeUtf8(rule.alloc(), "m"))); + + int streamId = rule.getStreamIdForRequestType(REQUEST_CHANNEL); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d1"), ByteBufUtil.writeUtf8(rule.alloc(), "m1"))); + testPublisher.tryEmitNext( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "d2"), ByteBufUtil.writeUtf8(rule.alloc(), "m2"))); + + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode( + allocator, streamId, new CustomRSocketException(0x00000404, "test"))); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(responsesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + Publisher response; + + switch (frameType) { + case REQUEST_FNF: + response = + testPublisher.mono().flatMap(p -> rule.socket.fireAndForget(p)).then(Mono.empty()); + break; + case REQUEST_RESPONSE: + response = testPublisher.mono().flatMap(p -> rule.socket.requestResponse(p)); + break; + case REQUEST_STREAM: + response = testPublisher.mono().flatMapMany(p -> rule.socket.requestStream(p)); + break; + case REQUEST_CHANNEL: + response = rule.socket.requestChannel(testPublisher.flux()); + break; + default: + throw new UnsupportedOperationException("illegal case"); + } + + response.subscribe(assertSubscriber); + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d"))); + + int streamId = rule.getStreamIdForRequestType(frameType); + + if (responsesCnt > 0) { + for (int i = 0; i < responsesCnt - 1; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("rd" + (i + 1)).getBytes()))); + } + + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + streamId, + false, + true, + true, + null, + Unpooled.wrappedBuffer(("rd" + responsesCnt).getBytes()))); + } + + if (framesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, streamId, framesCnt)); + } + + for (int i = 1; i < framesCnt; i++) { + testPublisher.next(ByteBufPayload.create(ByteBufUtil.writeUtf8(rule.alloc(), "d" + i))); + } + + assertThat(rule.connection.getSent()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, framesCnt) + .hasSize(framesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)) + .allMatch(ByteBuf::release); + + assertThat(assertSubscriber.isTerminated()) + .describedAs("Interaction Type :[%s]. Expected to be terminated", frameType) + .isTrue(); + + assertThat(assertSubscriber.values()) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames received", + frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(p -> p.release()); + + rule.assertHasNoLeaks(); + rule.connection.clearSendReceiveBuffers(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload( + BiFunction> sourceProducer) { + Payload invalidPayload = + ByteBufPayload.create( + ByteBufUtil.writeUtf8(rule.alloc(), "test"), + ByteBufUtil.writeUtf8(rule.alloc(), "test")); + invalidPayload.release(); + + Publisher source = sourceProducer.apply(invalidPayload, rule); + + StepVerifier.create(source, 1) + .expectError(IllegalReferenceCountException.class) + .verify(Duration.ofMillis(1000)); + } + + private static Stream>> refCntCases() { + return Stream.of( + (p, clientSocketRule) -> clientSocketRule.socket.fireAndForget(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestResponse(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestStream(p), + (p, clientSocketRule) -> clientSocketRule.socket.requestChannel(Mono.just(p)), + (p, clientSocketRule) -> { + Flux.from(clientSocketRule.connection.getSentAsPublisher()) + .filter(bb -> frameType(bb) == REQUEST_CHANNEL) + .doOnDiscard(ByteBuf.class, ReferenceCounted::release) + .subscribe( + bb -> { + clientSocketRule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode( + clientSocketRule.allocator, FrameHeaderCodec.streamId(bb), 1)); + bb.release(); + }); + + return clientSocketRule.socket.requestChannel(Flux.just(EmptyPayload.INSTANCE, p)); + }); + } + + @Test + public void ensuresThatNoOpsMustHappenUntilSubscriptionInCaseOfFnfCall() { + Payload payload1 = ByteBufPayload.create("abc1"); + Mono fnf1 = rule.socket.fireAndForget(payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Mono fnf2 = rule.socket.fireAndForget(payload2); + + assertThat(rule.connection.getSent()).isEmpty(); + + // checks that fnf2 should have id 1 even though it was generated later than fnf1 + AssertSubscriber voidAssertSubscriber2 = fnf2.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber2.assertTerminated().assertNoError(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 1) + // ensures that this is fnf1 with abc2 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes()))) + .matches(ReferenceCounted::release); + + rule.connection.clearSendReceiveBuffers(); + + // checks that fnf1 should have id 3 even though it was generated earlier + AssertSubscriber voidAssertSubscriber1 = fnf1.subscribeWith(AssertSubscriber.create(0)); + voidAssertSubscriber1.assertTerminated().assertNoError(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> frameType(bb) == REQUEST_FNF) + .matches(bb -> FrameHeaderCodec.streamId(bb) == 3) + // ensures that this is fnf1 with abc1 data + .matches( + bb -> + ByteBufUtil.equals( + RequestFireAndForgetFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes()))) + .matches(ReferenceCounted::release); + } + + @ParameterizedTest + @MethodSource("requestNInteractions") + public void ensuresThatNoOpsMustHappenUntilFirstRequestN( + FrameType frameType, BiFunction> interaction) { + Payload payload1 = ByteBufPayload.create("abc1"); + Publisher interaction1 = interaction.apply(rule, payload1); + + Payload payload2 = ByteBufPayload.create("abc2"); + Publisher interaction2 = interaction.apply(rule, payload2); + + assertThat(rule.connection.getSent()).isEmpty(); + + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(0); + interaction1.subscribe(assertSubscriber1); + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(0); + interaction2.subscribe(assertSubscriber2); + assertSubscriber1.assertNotTerminated().assertNoError(); + assertSubscriber2.assertNotTerminated().assertNoError(); + // even though we subscribed, nothing should happen until the first requestN + assertThat(rule.connection.getSent()).isEmpty(); + + // first request on the second interaction to ensure that stream id issuing on the first request + assertSubscriber2.request(1); + + assertThat(rule.connection.getSent()) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc2".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + + if (frameType == REQUEST_CHANNEL) { + assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 1, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } + + rule.connection.clearSendReceiveBuffers(); + + assertSubscriber1.request(1); + assertThat(rule.connection.getSent()) + .hasSize(frameType == REQUEST_CHANNEL ? 2 : 1) + .first() + .matches(bb -> frameType(bb) == frameType) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(rule.connection.getSent().iterator().next()) + + "}") + .matches( + bb -> { + switch (frameType) { + case REQUEST_RESPONSE: + return ByteBufUtil.equals( + RequestResponseFrameCodec.data(bb), + Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_STREAM: + return ByteBufUtil.equals( + RequestStreamFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + case REQUEST_CHANNEL: + return ByteBufUtil.equals( + RequestChannelFrameCodec.data(bb), Unpooled.wrappedBuffer("abc1".getBytes())); + } + + return false; + }) + .matches(ReferenceCounted::release); + + if (frameType == REQUEST_CHANNEL) { + assertThat(rule.connection.getSent()) + .element(1) + .matches(bb -> frameType(bb) == COMPLETE) + .matches( + bb -> FrameHeaderCodec.streamId(bb) == 3, + "Expected to have stream ID {1} but got {" + + FrameHeaderCodec.streamId(new ArrayList<>(rule.connection.getSent()).get(1)) + + "}") + .matches(ReferenceCounted::release); + } + } + + private static Stream requestNInteractions() { + return Stream.of( + Arguments.of( + REQUEST_RESPONSE, + (BiFunction>) + (rule, payload) -> rule.socket.requestResponse(payload)), + Arguments.of( + REQUEST_STREAM, + (BiFunction>) + (rule, payload) -> rule.socket.requestStream(payload)), + Arguments.of( + REQUEST_CHANNEL, + (BiFunction>) + (rule, payload) -> rule.socket.requestChannel(Flux.just(payload)))); + } + + @ParameterizedTest + @MethodSource("streamRacingCases") + @Disabled("Connection should take care of ordering if such is necessary") + public void ensuresCorrectOrderOfStreamIdIssuingInCaseOfRacing( + BiFunction> interaction1, + BiFunction> interaction2, + FrameType interactionType1, + FrameType interactionType2) { + Assumptions.assumeThat(interactionType1).isNotEqualTo(METADATA_PUSH); + Assumptions.assumeThat(interactionType2).isNotEqualTo(METADATA_PUSH); + for (int i = 1; i < RaceTestConstants.REPEATS; i += 4) { + Payload payload = DefaultPayload.create("test", "test"); + Publisher publisher1 = interaction1.apply(rule, payload); + Publisher publisher2 = interaction2.apply(rule, payload); + RaceTestUtils.race( + () -> publisher1.subscribe(AssertSubscriber.create()), + () -> publisher2.subscribe(AssertSubscriber.create())); + + assertThat(rule.connection.getSent()) + .extracting(FrameHeaderCodec::streamId) + .containsExactly(i, i + 2); + rule.connection.getSent().forEach(bb -> bb.release()); + rule.connection.getSent().clear(); + } + } + + public static Stream streamRacingCases() { + return Stream.of( + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + (BiFunction>) + (r, p) -> r.socket.requestResponse(p), + REQUEST_FNF, + REQUEST_RESPONSE), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestResponse(p), + (BiFunction>) + (r, p) -> r.socket.requestStream(p), + REQUEST_RESPONSE, + REQUEST_STREAM), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.requestStream(p), + (BiFunction>) + (r, p) -> { + AtomicBoolean subscribed = new AtomicBoolean(); + Flux just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true)); + return r.socket + .requestChannel(just) + .doFinally( + __ -> { + if (!subscribed.get()) { + p.release(); + } + }); + }, + REQUEST_STREAM, + REQUEST_CHANNEL), + Arguments.of( + (BiFunction>) + (r, p) -> { + AtomicBoolean subscribed = new AtomicBoolean(); + Flux just = Flux.just(p).doOnSubscribe((__) -> subscribed.set(true)); + return r.socket + .requestChannel(just) + .doFinally( + __ -> { + if (!subscribed.get()) { + p.release(); + } + }); + }, + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + REQUEST_CHANNEL, + REQUEST_FNF), + Arguments.of( + (BiFunction>) + (r, p) -> r.socket.metadataPush(p), + (BiFunction>) + (r, p) -> r.socket.fireAndForget(p), + METADATA_PUSH, + REQUEST_FNF)); + } + + @ParameterizedTest + @MethodSource("streamRacingCases") + @SuppressWarnings({"rawtypes", "unchecked"}) + public void shouldTerminateAllStreamsIfThereRacingBetweenDisposeAndRequests( + BiFunction> interaction1, + BiFunction> interaction2, + FrameType interactionType1, + FrameType interactionType2) { + for (int i = 1; i < RaceTestConstants.REPEATS; i++) { + Payload payload1 = ByteBufPayload.create("test", "test"); + Payload payload2 = ByteBufPayload.create("test", "test"); + AssertSubscriber assertSubscriber1 = AssertSubscriber.create(); + AssertSubscriber assertSubscriber2 = AssertSubscriber.create(); + Publisher publisher1 = interaction1.apply(rule, payload1); + Publisher publisher2 = interaction2.apply(rule, payload2); + RaceTestUtils.race( + () -> rule.socket.dispose(), + () -> publisher1.subscribe(assertSubscriber1), + () -> publisher2.subscribe(assertSubscriber2)); + + assertSubscriber1.await().assertTerminated(); + if (interactionType1 != REQUEST_FNF && interactionType1 != METADATA_PUSH) { + assertSubscriber1.assertError(ClosedChannelException.class); + } else { + try { + assertSubscriber1.assertError(ClosedChannelException.class); + } catch (Throwable t) { + // fnf call may be completed + assertSubscriber1.assertComplete(); + } + } + assertSubscriber2.await().assertTerminated(); + if (interactionType2 != REQUEST_FNF && interactionType2 != METADATA_PUSH) { + assertSubscriber2.assertError(ClosedChannelException.class); + } else { + try { + assertSubscriber2.assertError(ClosedChannelException.class); + } catch (Throwable t) { + // fnf call may be completed + assertSubscriber2.assertComplete(); + } + } + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + rule.connection.getSent().clear(); + + assertThat(payload1.refCnt()).isZero(); + assertThat(payload2.refCnt()).isZero(); + } + } + + @Test + // see https://github.com/rsocket/rsocket-java/issues/858 + public void testWorkaround858() { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + rule.socket.requestResponse(ByteBufPayload.create(buffer)).subscribe(); + + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); + + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_RESPONSE) + .matches(ByteBuf::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + + rule.assertHasNoLeaks(); + } + + @DisplayName("reassembles data") + @ParameterizedTest + @MethodSource("requestNInteractions") + void reassembleData( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload randomPayload = randomPayload(leaksTrackingByteBufAllocator); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, randomPayload); + + final Publisher responsePublisher = requestFunction.apply(rule, requestPayload); + StepVerifier.create(responsePublisher) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .assertNext( + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .thenCancel() + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + if (!rule.connection.getSent().isEmpty()) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + } + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @DisplayName("reassembles metadata") + @ParameterizedTest + @MethodSource("requestNInteractions") + void reassembleMetadata( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload metadataOnlyPayload = randomMetadataOnlyPayload(leaksTrackingByteBufAllocator); + List fragments = + prepareFragments(leaksTrackingByteBufAllocator, mtu, metadataOnlyPayload); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .assertNext( + responsePayload -> { + PayloadAssert.assertThat(responsePayload).isEqualTo(metadataOnlyPayload).hasNoLeaks(); + metadataOnlyPayload.release(); + }) + .thenCancel() + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + if (!rule.connection.getSent().isEmpty()) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + } + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if reassembling payload size exceeds {0}") + @MethodSource("requestNInteractions") + public void errorTooBigPayload( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final int maxInboundPayloadSize = ThreadLocalRandom.current().nextInt(mtu + 1, 4096); + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload responsePayload = + fixedSizePayload(leaksTrackingByteBufAllocator, maxInboundPayloadSize + 1); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, responsePayload); + responsePayload.release(); + + rule.setMaxInboundPayloadSize(maxInboundPayloadSize); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .expectErrorMessage(String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)) + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if fragment before the last is < min MTU {0}") + @MethodSource("requestNInteractions") + public void errorFragmentTooSmall( + FrameType frameType, + BiFunction> requestFunction) { + final int mtu = 32; + final LeaksTrackingByteBufAllocator leaksTrackingByteBufAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + final Payload requestPayload = genericPayload(leaksTrackingByteBufAllocator); + final Payload responsePayload = fixedSizePayload(leaksTrackingByteBufAllocator, 156); + List fragments = prepareFragments(leaksTrackingByteBufAllocator, mtu, responsePayload); + responsePayload.release(); + + StepVerifier.create(requestFunction.apply(rule, requestPayload)) + .then(() -> rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0]))) + .expectErrorMessage("Fragment is too small.") + .verify(); + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(frameType).hasNoLeaks(); + + if (frameType == REQUEST_CHANNEL) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(CANCEL).hasNoLeaks(); + + leaksTrackingByteBufAllocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(strings = {"stream", "channel"}) + // see https://github.com/rsocket/rsocket-java/issues/959 + public void testWorkaround959(String type) { + for (int i = 1; i < 20000; i += 2) { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(3); + if (type.equals("stream")) { + rule.socket.requestStream(ByteBufPayload.create(buffer)).subscribe(assertSubscriber); + } else if (type.equals("channel")) { + rule.socket + .requestChannel(Flux.just(ByteBufPayload.create(buffer))) + .subscribe(assertSubscriber); + } + + final ByteBuf payloadFrame = + PayloadFrameCodec.encode( + rule.alloc(), i, false, false, true, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + + RaceTestUtils.race( + () -> { + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame.copy()); + rule.connection.addToReceivedBuffer(payloadFrame); + }, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertThat(rule.connection.getSent()).allMatch(ByteBuf::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + + assertSubscriber.values().forEach(ReferenceCountUtil::safeRelease); + assertSubscriber.assertNoError(); + + rule.connection.clearSendReceiveBuffers(); + rule.assertHasNoLeaks(); + } + } + + public static class ClientSocketRule extends AbstractSocketRule { + + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + @Override + protected RSocketRequester newRSocket() { + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + return new RSocketRequester( + connection, + PayloadDecoder.ZERO_COPY, + StreamIdSupplier.clientSupplier(), + 0, + maxFrameLength, + maxInboundPayloadSize, + Integer.MAX_VALUE, + Integer.MAX_VALUE, + null, + (__) -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + } + + public int getStreamIdForRequestType(FrameType expectedFrameType) { + assertThat(connection.getSent().size()) + .describedAs("Unexpected frames sent.") + .isGreaterThanOrEqualTo(1); + List framesFound = new ArrayList<>(); + for (ByteBuf frame : connection.getSent()) { + FrameType frameType = frameType(frame); + if (frameType == expectedFrameType) { + return FrameHeaderCodec.streamId(frame); + } + framesFound.add(frameType); + } + throw new AssertionError( + "No frames sent with frame type: " + + expectedFrameType + + ", frames found: " + + framesFound); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java new file mode 100644 index 000000000..4f689e396 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketResponderTest.java @@ -0,0 +1,1269 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.ReassemblyUtils.ILLEGAL_REASSEMBLED_PAYLOAD_SIZE; +import static io.rsocket.core.TestRequesterResponderSupport.fixedSizePayload; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.core.TestRequesterResponderSupport.prepareFragments; +import static io.rsocket.core.TestRequesterResponderSupport.randomMetadataOnlyPayload; +import static io.rsocket.core.TestRequesterResponderSupport.randomPayload; +import static io.rsocket.frame.FrameHeaderCodec.frameType; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.ERROR; +import static io.rsocket.frame.FrameType.NEXT; +import static io.rsocket.frame.FrameType.NEXT_COMPLETE; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.CancelFrameCodec; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.PayloadFrameCodec; +import io.rsocket.frame.RequestChannelFrameCodec; +import io.rsocket.frame.RequestFireAndForgetFrameCodec; +import io.rsocket.frame.RequestNFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestSubscriber; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RSocketResponderTest { + + ServerSocketRule rule; + + @BeforeEach + public void setUp() { + Hooks.onNextDropped(ReferenceCountUtil::safeRelease); + Hooks.onErrorDropped(t -> {}); + rule = new ServerSocketRule(); + rule.init(); + } + + @AfterEach + public void tearDown() { + Hooks.resetOnErrorDropped(); + Hooks.resetOnNextDropped(); + rule.assertHasNoLeaks(); + } + + @Test + @Timeout(2_000) + @Disabled + public void testHandleKeepAlive() { + rule.connection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(rule.alloc(), true, 0, Unpooled.EMPTY_BUFFER)); + ByteBuf sent = rule.connection.awaitFrame(); + assertThat(frameType(sent)) + .describedAs("Unexpected frame sent.") + .isEqualTo(FrameType.KEEPALIVE); + /*Keep alive ack must not have respond flag else, it will result in infinite ping-pong of keep alive frames.*/ + assertThat(KeepAliveFrameCodec.respondFlag(sent)) + .describedAs("Unexpected keep-alive frame respond flag.") + .isEqualTo(false); + } + + @Test + @Timeout(2_000) + public void testHandleResponseFrameNoError() { + final int streamId = 4; + rule.connection.clearSendReceiveBuffers(); + final TestPublisher testPublisher = TestPublisher.create(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return testPublisher.mono(); + } + }); + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + testPublisher.complete(); + FrameAssert.assertThat(rule.connection.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + testPublisher.assertWasNotCancelled(); + } + + @Test + @Timeout(2_000) + public void testHandlerEmitsError() { + final int streamId = 4; + rule.prefetch = 1; + rule.sendRequest(streamId, FrameType.REQUEST_STREAM); + FrameAssert.assertThat(rule.connection.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("Request-Stream not implemented.") + .hasNoLeaks(); + } + + @Test + @Timeout(20_000) + public void testCancel() { + ByteBufAllocator allocator = rule.alloc(); + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return Mono.never().doOnCancel(() -> cancelled.set(true)); + } + }); + rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE); + + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); + + rule.connection.addToReceivedBuffer(CancelFrameCodec.encode(allocator, streamId)); + + assertThat(rule.connection.getSent()).describedAs("Unexpected frame sent.").isEmpty(); + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(ints = {128, 256, FRAME_LENGTH_MASK}) + @Timeout(2_000) + public void shouldThrownExceptionIfGivenPayloadIsExitsSizeAllowanceWithNoFragmentation( + int maxFrameLength) { + rule.setMaxFrameLength(maxFrameLength); + final int streamId = 4; + final AtomicBoolean cancelled = new AtomicBoolean(); + byte[] metadata = new byte[maxFrameLength]; + byte[] data = new byte[maxFrameLength]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + final Payload payload = DefaultPayload.create(data, metadata); + final RSocket acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload p) { + p.release(); + return Mono.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestStream(Payload p) { + p.release(); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .doOnNext(Payload::release) + .subscribe( + new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + subscription.request(1); + } + }); + return Flux.just(payload).doOnCancel(() -> cancelled.set(true)); + } + }; + rule.setAcceptingSocket(acceptingSocket); + + final Runnable[] runnables = { + () -> rule.sendRequest(streamId, FrameType.REQUEST_RESPONSE), + () -> rule.sendRequest(streamId, FrameType.REQUEST_STREAM), + () -> rule.sendRequest(streamId, FrameType.REQUEST_CHANNEL) + }; + + for (Runnable runnable : runnables) { + rule.connection.clearSendReceiveBuffers(); + runnable.run(); + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == FrameType.ERROR) + .matches( + bb -> + ErrorFrameCodec.dataUtf8(bb) + .contains(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, maxFrameLength))) + .matches(ReferenceCounted::release); + + assertThat(cancelled.get()).describedAs("Subscription not cancelled.").isTrue(); + } + + rule.assertHasNoLeaks(); + } + + @Test + public void checkNoLeaksOnRacingCancelFromRequestChannelAndNextFromUpstream() { + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + final Sinks.One sink = Sinks.one(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return sink.asMono().flux(); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, true, true, metadata3, data3); + + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + () -> { + assertSubscriber.cancel(); + sink.tryEmitEmpty(); + }); + + assertThat(assertSubscriber.values()).allMatch(ReferenceCounted::release); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnComplete(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestChannelTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + ((Flux) payloads) + .doOnNext(ReferenceCountUtil::safeRelease) + .subscribe(assertSubscriber); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + sink.complete(); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnCancel(1).expectNothing(); + rule.assertHasNoLeaks(); + } + } + + @Test + public void + checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromUpstreamOnErrorFromRequestChannelTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + FluxSink[] sinks = new FluxSink[1]; + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + + return Flux.create( + sink -> { + sinks[0] = sink; + }, + FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_CHANNEL); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("def3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + + ByteBuf requestNFrame = RequestNFrameCodec.encode(allocator, 1, Integer.MAX_VALUE); + + ByteBuf m1 = allocator.buffer(); + m1.writeCharSequence("m1", CharsetUtil.UTF_8); + ByteBuf d1 = allocator.buffer(); + d1.writeCharSequence("d1", CharsetUtil.UTF_8); + Payload np1 = ByteBufPayload.create(d1, m1); + + ByteBuf m2 = allocator.buffer(); + m2.writeCharSequence("m2", CharsetUtil.UTF_8); + ByteBuf d2 = allocator.buffer(); + d2.writeCharSequence("d2", CharsetUtil.UTF_8); + Payload np2 = ByteBufPayload.create(d2, m2); + + ByteBuf m3 = allocator.buffer(); + m3.writeCharSequence("m3", CharsetUtil.UTF_8); + ByteBuf d3 = allocator.buffer(); + d3.writeCharSequence("d3", CharsetUtil.UTF_8); + Payload np3 = ByteBufPayload.create(d3, m3); + + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(requestNFrame), + () -> rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3), + () -> { + sink.next(np1); + sink.next(np2); + sink.next(np3); + sink.error(new RuntimeException()); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + assertThat(assertSubscriber.values()) + .allMatch( + msg -> { + ReferenceCountUtil.safeRelease(msg); + return msg.refCnt() == 0; + }); + rule.assertHasNoLeaks(); + testRequestInterceptor.expectOnStart(1, REQUEST_CHANNEL).expectOnError(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestStreamTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + + testRequestInterceptor.expectOnStart(1, REQUEST_STREAM).expectOnCancel(1).expectNothing(); + } + } + + @Test + public void checkNoLeaksOnRacingBetweenDownstreamCancelAndOnNextFromRequestResponseTest1() { + Hooks.onErrorDropped((e) -> {}); + ByteBufAllocator allocator = rule.alloc(); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + rule.setRequestInterceptor(testRequestInterceptor); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + Operators.MonoSubscriber[] sources = new Operators.MonoSubscriber[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + sources[0] = new Operators.MonoSubscriber<>(actual); + actual.onSubscribe(sources[0]); + } + }; + } + }, + Integer.MAX_VALUE); + + rule.sendRequest(1, REQUEST_RESPONSE); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + RaceTestUtils.race( + () -> rule.connection.addToReceivedBuffer(cancelFrame), + () -> { + sources[0].complete(ByteBufPayload.create("d1", "m1")); + }); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, REQUEST_RESPONSE) + .assertNext( + e -> + assertThat(e.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_CANCEL)) + .expectNothing(); + } + } + + @Test + public void simpleDiscardRequestStreamTest() { + ByteBufAllocator allocator = rule.alloc(); + FluxSink[] sinks = new FluxSink[1]; + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.create(sink -> sinks[0] = sink, FluxSink.OverflowStrategy.IGNORE); + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + FluxSink sink = sinks[0]; + + sink.next(ByteBufPayload.create("d1", "m1")); + sink.next(ByteBufPayload.create("d2", "m2")); + sink.next(ByteBufPayload.create("d3", "m3")); + rule.connection.addToReceivedBuffer(cancelFrame); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @Test + public void simpleDiscardRequestChannelTest() { + ByteBufAllocator allocator = rule.alloc(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return (Flux) payloads; + } + }, + 1); + + rule.sendRequest(1, REQUEST_STREAM); + + ByteBuf cancelFrame = CancelFrameCodec.encode(allocator, 1); + + ByteBuf metadata1 = allocator.buffer(); + metadata1.writeCharSequence("abc1", CharsetUtil.UTF_8); + ByteBuf data1 = allocator.buffer(); + data1.writeCharSequence("def1", CharsetUtil.UTF_8); + ByteBuf nextFrame1 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata1, data1); + + ByteBuf metadata2 = allocator.buffer(); + metadata2.writeCharSequence("abc2", CharsetUtil.UTF_8); + ByteBuf data2 = allocator.buffer(); + data2.writeCharSequence("def2", CharsetUtil.UTF_8); + ByteBuf nextFrame2 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata2, data2); + + ByteBuf metadata3 = allocator.buffer(); + metadata3.writeCharSequence("abc3", CharsetUtil.UTF_8); + ByteBuf data3 = allocator.buffer(); + data3.writeCharSequence("de3", CharsetUtil.UTF_8); + ByteBuf nextFrame3 = + PayloadFrameCodec.encode(allocator, 1, false, false, true, metadata3, data3); + rule.connection.addToReceivedBuffer(nextFrame1, nextFrame2, nextFrame3); + + rule.connection.addToReceivedBuffer(cancelFrame); + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("encodeDecodePayloadCases") + public void verifiesThatFrameWithNoMetadataHasDecodedCorrectlyIntoPayload( + FrameType frameType, int framesCnt, int responsesCnt) { + ByteBufAllocator allocator = rule.alloc(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(framesCnt); + TestPublisher testPublisher = TestPublisher.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + Mono.just(payload).subscribe(assertSubscriber); + return testPublisher.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(assertSubscriber); + return testPublisher.flux(); + } + }, + 1); + + rule.sendRequest(1, frameType, ByteBufPayload.create("d")); + + // if responses number is bigger than 1 we have to send one extra requestN + if (responsesCnt > 1) { + rule.connection.addToReceivedBuffer( + RequestNFrameCodec.encode(allocator, 1, responsesCnt - 1)); + } + + // respond with specific number of elements + for (int i = 0; i < responsesCnt; i++) { + testPublisher.next(ByteBufPayload.create("rd" + i)); + } + + // Listen to incoming frames. Valid for RequestChannel case only + if (framesCnt > 1) { + for (int i = 1; i < responsesCnt; i++) { + rule.connection.addToReceivedBuffer( + PayloadFrameCodec.encode( + allocator, + 1, + false, + false, + true, + null, + Unpooled.wrappedBuffer(("d" + (i + 1)).getBytes()))); + } + } + + if (responsesCnt > 0) { + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) != REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe %s frames sent", frameType, responsesCnt) + .hasSize(responsesCnt) + .allMatch(bb -> !FrameHeaderCodec.hasMetadata(bb)); + } + + if (framesCnt > 1) { + assertThat(rule.connection.getSent().stream().filter(bb -> frameType(bb) == REQUEST_N)) + .describedAs( + "Interaction Type :[%s]. Expected to observe single RequestN(%s) frame", + frameType, framesCnt - 1) + .hasSize(1) + .first() + .matches(bb -> RequestNFrameCodec.requestN(bb) == (framesCnt - 1)); + } + + assertThat(rule.connection.getSent()).allMatch(ReferenceCounted::release); + + assertThat(assertSubscriber.awaitAndAssertNextValueCount(framesCnt).values()) + .hasSize(framesCnt) + .allMatch(p -> !p.hasMetadata()) + .allMatch(ReferenceCounted::release); + + rule.assertHasNoLeaks(); + } + + static Stream encodeDecodePayloadCases() { + return Stream.of( + Arguments.of(REQUEST_FNF, 1, 0), + Arguments.of(REQUEST_RESPONSE, 1, 1), + Arguments.of(REQUEST_STREAM, 1, 5), + Arguments.of(REQUEST_CHANNEL, 5, 5)); + } + + @ParameterizedTest + @MethodSource("refCntCases") + public void ensureSendsErrorOnIllegalRefCntPayload(FrameType frameType) { + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Mono.just(invalidPayload); + } + + @Override + public Flux requestStream(Payload payload) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Payload invalidPayload = ByteBufPayload.create("test", "test"); + invalidPayload.release(); + return Flux.just(invalidPayload); + } + }); + + rule.sendRequest(1, frameType); + + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches( + bb -> frameType(bb) == ERROR, + "Expect frame type to be {" + + ERROR + + "} but was {" + + frameType(rule.connection.getSent().iterator().next()) + + "}") + .matches(ByteBuf::release); + } + + private static Stream refCntCases() { + return Stream.of(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + + @Test + // see https://github.com/rsocket/rsocket-java/issues/858 + public void testWorkaround858() { + ByteBuf buffer = rule.alloc().buffer(); + buffer.writeCharSequence("test", CharsetUtil.UTF_8); + + TestPublisher testPublisher = TestPublisher.create(); + + rule.setAcceptingSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).doOnNext(ReferenceCounted::release).subscribe(); + + return testPublisher.flux(); + } + }); + + rule.connection.addToReceivedBuffer( + RequestChannelFrameCodec.encodeReleasingPayload( + rule.alloc(), 1, false, 1, ByteBufPayload.create(buffer))); + rule.connection.addToReceivedBuffer( + ErrorFrameCodec.encode(rule.alloc(), 1, new RuntimeException("test"))); + + assertThat(rule.connection.getSent()) + .hasSize(1) + .first() + .matches(bb -> FrameHeaderCodec.frameType(bb) == REQUEST_N) + .matches(ReferenceCounted::release); + + assertThat(rule.socket.isDisposed()).isFalse(); + testPublisher.assertWasCancelled(); + + rule.assertHasNoLeaks(); + } + + static Stream requestCases() { + return Stream.of(REQUEST_FNF, REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + } + + @DisplayName("reassembles payload") + @ParameterizedTest + @MethodSource("requestCases") + void reassemblePayload(FrameType frameType) { + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final Payload randomPayload = randomPayload(rule.allocator); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(frameType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasNoLeaks(); + if (frameType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @DisplayName("reassembles metadata") + @ParameterizedTest + @MethodSource("requestCases") + void reassembleMetadataOnly(FrameType frameType) { + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final Payload randomMetadataOnlyPayload = randomMetadataOnlyPayload(rule.allocator); + List fragments = + prepareFragments(rule.allocator, mtu, randomMetadataOnlyPayload, frameType); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()) + .isEqualTo(randomMetadataOnlyPayload) + .hasNoLeaks(); + randomMetadataOnlyPayload.release(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(frameType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasNoLeaks(); + if (frameType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if reassembling payload size exceeds {0}") + @MethodSource("requestCases") + public void errorTooBigPayload(FrameType frameType) { + final int mtu = ThreadLocalRandom.current().nextInt(64, 256); + final int maxInboundPayloadSize = ThreadLocalRandom.current().nextInt(mtu + 1, 4096); + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setMaxInboundPayloadSize(maxInboundPayloadSize); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + final Payload randomPayload = fixedSizePayload(rule.allocator, maxInboundPayloadSize + 1); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + randomPayload.release(); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isNull(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(ERROR) + .hasData( + "Failed to reassemble payload. Cause: " + + String.format(ILLEGAL_REASSEMBLED_PAYLOAD_SIZE, maxInboundPayloadSize)) + .hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest(name = "throws error if fragment before the last is < min MTU {0}") + @MethodSource("requestCases") + public void errorFragmentTooSmall(FrameType frameType) { + final int mtu = 32; + AtomicReference receivedPayload = new AtomicReference<>(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)); + } + }); + final Payload randomPayload = fixedSizePayload(rule.allocator, 156); + List fragments = prepareFragments(rule.allocator, mtu, randomPayload, frameType); + randomPayload.release(); + + rule.connection.addToReceivedBuffer(fragments.toArray(new ByteBuf[0])); + + PayloadAssert.assertThat(receivedPayload.get()).isNull(); + + if (frameType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(ERROR) + .hasData("Failed to reassemble payload. Cause: Fragment is too small.") + .hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("requestCases") + void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored_ReassemblyCase( + FrameType requestType) { + AtomicReference receivedPayload = new AtomicReference<>(); + final Sinks.Empty delayer = Sinks.empty(); + rule.setAcceptingSocket( + new RSocket() { + + @Override + public Mono fireAndForget(Payload payload) { + receivedPayload.set(payload); + return delayer.asMono(); + } + + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + }); + final Payload randomPayload1 = fixedSizePayload(rule.allocator, 128); + final List fragments1 = + prepareFragments(rule.allocator, 64, randomPayload1, requestType); + final Payload randomPayload2 = fixedSizePayload(rule.allocator, 128); + final List fragments2 = + prepareFragments(rule.allocator, 64, randomPayload2, requestType); + randomPayload2.release(); + rule.connection.addToReceivedBuffer(fragments1.remove(0)); + rule.connection.addToReceivedBuffer(fragments2.remove(0)); + + rule.connection.addToReceivedBuffer(fragments1.toArray(new ByteBuf[0])); + if (requestType != REQUEST_CHANNEL) { + rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); + delayer.tryEmitEmpty(); + } else { + delayer.tryEmitEmpty(); + rule.connection.addToReceivedBuffer(PayloadFrameCodec.encodeComplete(rule.allocator, 1)); + rule.connection.addToReceivedBuffer(fragments2.toArray(new ByteBuf[0])); + } + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload1).hasNoLeaks(); + randomPayload1.release(); + + if (requestType != REQUEST_FNF) { + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(requestType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasNoLeaks(); + + if (requestType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + } + + rule.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("requestCases") + void receivingRequestOnStreamIdThaIsAlreadyInUseMUSTBeIgnored(FrameType requestType) { + Assumptions.assumeThat(requestType).isNotEqualTo(REQUEST_FNF); + AtomicReference receivedPayload = new AtomicReference<>(); + final Sinks.One delayer = Sinks.one(); + rule.setAcceptingSocket( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + receivedPayload.set(payload); + return Mono.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestStream(Payload payload) { + receivedPayload.set(payload); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads).subscribe(receivedPayload::set, null, null, s -> s.request(1)); + return Flux.just(genericPayload(rule.allocator)).delaySubscription(delayer.asMono()); + } + }); + final Payload randomPayload1 = fixedSizePayload(rule.allocator, 64); + final Payload randomPayload2 = fixedSizePayload(rule.allocator, 64); + rule.sendRequest(1, requestType, randomPayload1.retain()); + rule.sendRequest(1, requestType, randomPayload2); + + delayer.tryEmitEmpty(); + + PayloadAssert.assertThat(receivedPayload.get()).isEqualTo(randomPayload1).hasNoLeaks(); + randomPayload1.release(); + + FrameAssert.assertThat(rule.connection.getSent().poll()) + .typeOf(requestType == REQUEST_RESPONSE ? NEXT_COMPLETE : NEXT) + .hasNoLeaks(); + + if (requestType != REQUEST_RESPONSE) { + FrameAssert.assertThat(rule.connection.getSent().poll()).typeOf(COMPLETE).hasNoLeaks(); + } + + rule.assertHasNoLeaks(); + } + + public static class ServerSocketRule extends AbstractSocketRule { + + private RSocket acceptingSocket; + private volatile int prefetch; + private RequestInterceptor requestInterceptor; + protected Sinks.Empty onCloseSink; + + @Override + protected void doInit() { + acceptingSocket = + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + }; + super.doInit(); + } + + public void setAcceptingSocket(RSocket acceptingSocket) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = Integer.MAX_VALUE; + super.doInit(); + } + + public void setRequestInterceptor(RequestInterceptor requestInterceptor) { + this.requestInterceptor = requestInterceptor; + super.doInit(); + } + + public void setAcceptingSocket(RSocket acceptingSocket, int prefetch) { + this.acceptingSocket = acceptingSocket; + connection = new TestDuplexConnection(alloc()); + connectSub = TestSubscriber.create(); + this.prefetch = prefetch; + super.doInit(); + } + + @Override + protected RSocketResponder newRSocket() { + onCloseSink = Sinks.empty(); + return new RSocketResponder( + connection, + acceptingSocket, + PayloadDecoder.ZERO_COPY, + null, + 0, + maxFrameLength, + maxInboundPayloadSize, + __ -> requestInterceptor, + onCloseSink); + } + + private void sendRequest(int streamId, FrameType frameType) { + sendRequest(streamId, frameType, EmptyPayload.INSTANCE); + } + + private void sendRequest(int streamId, FrameType frameType, Payload payload) { + ByteBuf request; + + switch (frameType) { + case REQUEST_CHANNEL: + request = + RequestChannelFrameCodec.encodeReleasingPayload( + allocator, streamId, false, prefetch, payload); + break; + case REQUEST_STREAM: + request = + RequestStreamFrameCodec.encodeReleasingPayload( + allocator, streamId, prefetch, payload); + break; + case REQUEST_RESPONSE: + request = RequestResponseFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + case REQUEST_FNF: + request = + RequestFireAndForgetFrameCodec.encodeReleasingPayload(allocator, streamId, payload); + break; + default: + throw new IllegalArgumentException("unsupported type: " + frameType); + } + + connection.addToReceivedBuffer(request); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java new file mode 100644 index 000000000..90e881257 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerFragmentationTest.java @@ -0,0 +1,64 @@ +package io.rsocket.core; + +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestServerTransport; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class RSocketServerFragmentationTest { + + @Test + public void serverErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketServer.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void serverSucceedsWithEnabledFragmentationOnSufficientMtu() { + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().fragment(100).bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void serverSucceedsWithDisabledFragmentation() { + TestServerTransport transport = new TestServerTransport(); + Closeable closeable = RSocketServer.create().bind(transport).block(); + closeable.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void clientErrorsWithEnabledFragmentationOnInsufficientMtu() { + Assertions.assertThatIllegalArgumentException() + .isThrownBy(() -> RSocketConnector.create().fragment(2)) + .withMessage("The smallest allowed mtu size is 64 bytes, provided: 2"); + } + + @Test + public void clientSucceedsWithEnabledFragmentationOnSufficientMtu() { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.create().fragment(100).connect(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void clientSucceedsWithDisabledFragmentation() { + TestClientTransport transport = new TestClientTransport(); + RSocketConnector.connectWith(transport).block(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .typeOf(FrameType.SETUP) + .hasNoLeaks(); + transport.testConnection().dispose(); + transport.alloc().assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java new file mode 100644 index 000000000..a335ac1f3 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketServerTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2015-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Closeable; +import io.rsocket.FrameAssert; +import io.rsocket.RSocket; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.RequestResponseFrameCodec; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.test.util.TestServerTransport; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Random; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class RSocketServerTest { + + @Test + public void unexpectedFramesBeforeSetupFrame() { + TestServerTransport transport = new TestServerTransport(); + RSocketServer.create().bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + duplexConnection.addToReceivedBuffer( + KeepAliveFrameCodec.encode(duplexConnection.alloc(), false, 1, Unpooled.EMPTY_BUFFER)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()) + .isNotNull() + .typeOf(FrameType.ERROR) + .hasData("SETUP or RESUME frame must be received before any others") + .hasStreamIdZero() + .hasNoLeaks(); + duplexConnection.alloc().assertHasNoLeaks(); + } + + @Test + public void timeoutOnNoFirstFrame() { + final VirtualTimeScheduler scheduler = VirtualTimeScheduler.getOrSet(); + TestServerTransport transport = new TestServerTransport(); + try { + RSocketServer.create().maxTimeToFirstFrame(Duration.ofMinutes(2)).bind(transport).block(); + + final TestDuplexConnection duplexConnection = transport.connect(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + Assertions.assertThat(duplexConnection.isDisposed()).isFalse(); + + scheduler.advanceTimeBy(Duration.ofMinutes(1)); + + StepVerifier.create(duplexConnection.onClose()) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(10)); + + FrameAssert.assertThat(duplexConnection.pollFrame()).isNull(); + } finally { + transport.alloc().assertHasNoLeaks(); + VirtualTimeScheduler.reset(); + } + } + + @Test + public void ensuresMaxFrameLengthCanNotBeLessThenMtu() { + RSocketServer.create() + .fragment(128) + .bind(new TestServerTransport().withMaxFrameLength(64)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maximumTransmissionUnit[128] exceeds configured maxFrameLength[64]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPayloadSize() { + RSocketServer.create() + .maxInboundPayloadSize(128) + .bind(new TestServerTransport().withMaxFrameLength(256)) + .as(StepVerifier::create) + .expectErrorMessage("Configured maxFrameLength[256] exceeds maxPayloadSize[128]") + .verify(); + } + + @Test + public void ensuresMaxFrameLengthCanNotBeGreaterThenMaxPossibleFrameLength() { + RSocketServer.create() + .bind(new TestServerTransport().withMaxFrameLength(Integer.MAX_VALUE)) + .as(StepVerifier::create) + .expectErrorMessage( + "Configured maxFrameLength[" + + Integer.MAX_VALUE + + "] " + + "exceeds maxFrameLength limit " + + FRAME_LENGTH_MASK) + .verify(); + } + + @Test + public void unexpectedFramesBeforeSetup() { + Sinks.Empty connectedSink = Sinks.empty(); + + TestServerTransport transport = new TestServerTransport(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> { + connectedSink.tryEmitEmpty(); + return Mono.just(new RSocket() {}); + }) + .bind(transport) + .block(); + + byte[] bytes = new byte[16_000_000]; + new Random().nextBytes(bytes); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer( + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.EMPTY_BUFFER, + ByteBufAllocator.DEFAULT.buffer(bytes.length).writeBytes(bytes))); + + StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); + assertThat(connectedSink.scan(Scannable.Attr.TERMINATED)) + .as("Connection should not succeed") + .isFalse(); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("SETUP or RESUME frame must be received before any others") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); + } + + @Test + public void ensuresErrorFrameDeliveredPriorConnectionDisposal() { + TestServerTransport transport = new TestServerTransport(); + Closeable server = + RSocketServer.create() + .acceptor( + (setup, sendingSocket) -> Mono.error(new RejectedSetupException("ACCESS_DENIED"))) + .bind(transport) + .block(); + + TestDuplexConnection connection = transport.connect(); + connection.addToReceivedBuffer( + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, + false, + 0, + 1, + Unpooled.EMPTY_BUFFER, + "metadata_type", + "data_type", + EmptyPayload.INSTANCE)); + + StepVerifier.create(connection.onClose()).expectComplete().verify(Duration.ofSeconds(30)); + FrameAssert.assertThat(connection.pollFrame()) + .hasStreamIdZero() + .hasData("ACCESS_DENIED") + .hasNoLeaks(); + server.dispose(); + transport.alloc().assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java new file mode 100644 index 000000000..e01e6ebdc --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RSocketTest.java @@ -0,0 +1,605 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.exceptions.CustomRSocketException; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.LocalDuplexConnection; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.atomic.AtomicReference; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Publisher; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; + +public class RSocketTest { + + public final SocketRule rule = new SocketRule(); + + @BeforeEach + public void setup() { + rule.init(); + } + + @AfterEach + public void tearDownAndCheckOnLeaks() { + rule.alloc().assertHasNoLeaks(); + } + + @Test + public void rsocketDisposalShouldEndupWithNoErrorsOnClose() { + RSocket requestHandlingRSocket = + new RSocket() { + final Disposable disposable = Disposables.single(); + + @Override + public void dispose() { + disposable.dispose(); + } + + @Override + public boolean isDisposed() { + return disposable.isDisposed(); + } + }; + rule.setRequestAcceptor(requestHandlingRSocket); + rule.crs + .onClose() + .as(StepVerifier::create) + .expectSubscription() + .then(rule.crs::dispose) + .expectComplete() + .verify(Duration.ofMillis(100)); + + Assertions.assertThat(requestHandlingRSocket.isDisposed()).isTrue(); + } + + @Test + @Timeout(2_000) + public void testRequestReplyNoError() { + StepVerifier.create(rule.crs.requestResponse(DefaultPayload.create("hello"))) + .expectNextCount(1) + .expectComplete() + .verify(); + } + + @Test + @Timeout(2000) + public void testHandlerEmitsError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error(new NullPointerException("Deliberate exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("Deliberate exception.")) + .verify(Duration.ofMillis(100)); + } + + @Test + @Timeout(2000) + public void testHandlerEmitsCustomError() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error( + new CustomRSocketException(0x00000501, "Deliberate Custom exception.")); + } + }); + rule.crs + .requestResponse(EmptyPayload.INSTANCE) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("Deliberate Custom exception.") + .hasFieldOrPropertyWithValue("errorCode", 0x00000501)) + .verify(); + } + + @Test + @Timeout(2000) + public void testRequestPropagatesCorrectlyForRequestChannel() { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + // specifically limits request to 3 in order to prevent 256 request from limitRate + // hidden on the responder side + .take(3, true); + } + }); + + Flux.range(0, 3) + .map(i -> DefaultPayload.create("" + i)) + .as(rule.crs::requestChannel) + .as(publisher -> StepVerifier.create(publisher, 3)) + .expectSubscription() + .expectNextCount(3) + .expectComplete() + .verify(Duration.ofMillis(5000)); + } + + @Test + @Timeout(2000) + public void testStream() { + Flux responses = rule.crs.requestStream(DefaultPayload.create("Payload In")); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test + @Timeout(200000) + public void testChannel() { + Flux requests = + Flux.range(0, 10).map(i -> DefaultPayload.create("streaming in -> " + i)); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectNextCount(10).expectComplete().verify(); + } + + @Test + @Timeout(2000) + public void testErrorPropagatesCorrectly() { + AtomicReference error = new AtomicReference<>(); + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads).doOnError(error::set); + } + }); + Flux requests = Flux.error(new RuntimeException("test")); + Flux responses = rule.crs.requestChannel(requests); + StepVerifier.create(responses).expectErrorMessage("test").verify(); + Assertions.assertThat(error.get()).isNull(); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion1() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void requestChannelCase_StreamIsTerminatedAfterBothSidesSentCompletion2() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + completeFromRequesterPublisher(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_CancellationFromResponderShouldLeaveStreamInHalfClosedStateWithNextCompletionPossibleFromRequester() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + } + + @Test + public void + requestChannelCase_CompletionFromRequesterShouldLeaveStreamInHalfClosedStateWithNextCancellationPossibleFromResponder() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + completeFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + cancelFromResponderSubscriber(requesterPublisher, responderSubscriber); + } + + @Test + public void + requestChannelCase_ensureThatRequesterSubscriberCancellationTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + cancelFromRequesterSubscriber( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + @Test + public void requestChannelCase_ErrorFromResponderShouldTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + errorFromResponderPublisher( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + @Test + public void requestChannelCase_ErrorFromRequesterShouldTerminatesStreamsOnBothSides() { + TestPublisher requesterPublisher = TestPublisher.create(); + AssertSubscriber requesterSubscriber = new AssertSubscriber<>(0); + + AssertSubscriber responderSubscriber = new AssertSubscriber<>(0); + TestPublisher responderPublisher = TestPublisher.create(); + + initRequestChannelCase( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + + nextFromResponderPublisher(responderPublisher, requesterSubscriber); + + nextFromRequesterPublisher(requesterPublisher, responderSubscriber); + + // ensures both sides are terminated + errorFromRequesterPublisher( + requesterPublisher, requesterSubscriber, responderPublisher, responderSubscriber); + } + + void initRequestChannelCase( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + rule.setRequestAcceptor( + new RSocket() { + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(responderSubscriber); + return responderPublisher.flux(); + } + }); + + rule.crs.requestChannel(requesterPublisher).subscribe(requesterSubscriber); + + requesterPublisher.assertWasSubscribed(); + requesterSubscriber.assertSubscribed(); + + responderSubscriber.assertNotSubscribed(); + responderPublisher.assertWasNotSubscribed(); + + // firstRequest + requesterSubscriber.request(1); + requesterPublisher.assertMaxRequested(1); + requesterPublisher.next(DefaultPayload.create("initialData", "initialMetadata")); + + responderSubscriber.assertSubscribed(); + responderPublisher.assertWasSubscribed(); + } + + void nextFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that outerUpstream and innerSubscriber is not terminated so the requestChannel + requesterPublisher.assertSubscribers(1); + responderSubscriber.assertNotTerminated(); + + responderSubscriber.request(6); + requesterPublisher.next( + DefaultPayload.create("d1", "m1"), + DefaultPayload.create("d2"), + DefaultPayload.create("d3", "m3"), + DefaultPayload.create("d4"), + DefaultPayload.create("d5", "m5")); + + List innerPayloads = responderSubscriber.awaitAndAssertNextValueCount(6).values(); + Assertions.assertThat(innerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("initialData", "d1", "d2", "d3", "d4", "d5"); + Assertions.assertThat(innerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, true, false, true, false, true); + Assertions.assertThat(innerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("initialMetadata", "m1", "", "m3", "", "m5"); + } + + void completeFromRequesterPublisher( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + requesterPublisher.complete(); + responderSubscriber.assertTerminated(); + requesterPublisher.assertNoSubscribers(); + } + + void cancelFromResponderSubscriber( + TestPublisher requesterPublisher, AssertSubscriber responderSubscriber) { + // ensures that after sending complete upstream part is closed + responderSubscriber.cancel(); + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + void nextFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that downstream is not terminated so the requestChannel state is half-closed + responderPublisher.assertSubscribers(1); + requesterSubscriber.assertNotTerminated(); + + // ensures responderPublisher can send messages and outerSubscriber can receive them + requesterSubscriber.request(5); + responderPublisher.next( + DefaultPayload.create("rd1", "rm1"), + DefaultPayload.create("rd2"), + DefaultPayload.create("rd3", "rm3"), + DefaultPayload.create("rd4"), + DefaultPayload.create("rd5", "rm5")); + + List outerPayloads = requesterSubscriber.awaitAndAssertNextValueCount(5).values(); + Assertions.assertThat(outerPayloads.stream().map(Payload::getDataUtf8)) + .containsExactly("rd1", "rd2", "rd3", "rd4", "rd5"); + Assertions.assertThat(outerPayloads.stream().map(Payload::hasMetadata)) + .containsExactly(true, false, true, false, true); + Assertions.assertThat(outerPayloads.stream().map(Payload::getMetadataUtf8)) + .containsExactly("rm1", "", "rm3", "", "rm5"); + } + + void completeFromResponderPublisher( + TestPublisher responderPublisher, AssertSubscriber requesterSubscriber) { + // ensures that after sending complete inner upstream is closed + responderPublisher.complete(); + requesterSubscriber.assertTerminated(); + responderPublisher.assertNoSubscribers(); + } + + void cancelFromRequesterSubscriber( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + requesterSubscriber.cancel(); + // error should be propagated + responderSubscriber.assertTerminated(); + responderPublisher.assertWasCancelled(); + responderPublisher.assertNoSubscribers(); + // ensures that cancellation is propagated to the actual upstream + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + static final CustomRSocketException EXCEPTION = new CustomRSocketException(123456, "test"); + + void errorFromResponderPublisher( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + responderPublisher.error(EXCEPTION); + // error should be propagated + responderSubscriber.assertTerminated().assertError(CancellationException.class); + requesterSubscriber + .assertTerminated() + .assertError(CustomRSocketException.class) + .assertErrorMessage("test"); + // ensures that cancellation is propagated to the actual upstream + requesterPublisher.assertWasCancelled(); + requesterPublisher.assertNoSubscribers(); + } + + void errorFromRequesterPublisher( + TestPublisher requesterPublisher, + AssertSubscriber requesterSubscriber, + TestPublisher responderPublisher, + AssertSubscriber responderSubscriber) { + // ensures that after sending cancel the whole requestChannel is terminated + requesterPublisher.error(EXCEPTION); + // error should be propagated + responderSubscriber + .assertTerminated() + .assertError(CustomRSocketException.class) + .assertErrorMessage("test"); + requesterSubscriber + .assertTerminated() + .assertError(CustomRSocketException.class) + .assertErrorMessage("test"); + + // ensures that cancellation is propagated to the actual upstream + responderPublisher.assertWasCancelled(); + responderPublisher.assertNoSubscribers(); + } + + public static class SocketRule { + + Sinks.Many serverProcessor; + Sinks.Many clientProcessor; + private RSocketRequester crs; + + @SuppressWarnings("unused") + private RSocketResponder srs; + + private RSocket requestAcceptor; + + private LeaksTrackingByteBufAllocator allocator; + protected Sinks.Empty thisClosedSink; + protected Sinks.Empty otherClosedSink; + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public void init() { + allocator = LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + serverProcessor = Sinks.many().multicast().directBestEffort(); + clientProcessor = Sinks.many().multicast().directBestEffort(); + + this.thisClosedSink = Sinks.empty(); + this.otherClosedSink = Sinks.empty(); + + LocalDuplexConnection serverConnection = + new LocalDuplexConnection("server", allocator, clientProcessor, serverProcessor); + LocalDuplexConnection clientConnection = + new LocalDuplexConnection("client", allocator, serverProcessor, clientProcessor); + + clientConnection.onClose().doFinally(__ -> serverConnection.dispose()).subscribe(); + serverConnection.onClose().doFinally(__ -> clientConnection.dispose()).subscribe(); + + requestAcceptor = + null != requestAcceptor + ? requestAcceptor + : new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 10) + .map(i -> DefaultPayload.create("server got -> [" + payload + "]")); + } + + @Override + public Flux requestChannel(Publisher payloads) { + Flux.from(payloads) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")) + .subscribe(); + + return Flux.range(1, 10) + .map( + payload -> + DefaultPayload.create("server got -> [" + payload.toString() + "]")); + } + }; + + srs = + new RSocketResponder( + serverConnection, + requestAcceptor, + PayloadDecoder.DEFAULT, + null, + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + __ -> null, + otherClosedSink); + + crs = + new RSocketRequester( + clientConnection, + PayloadDecoder.DEFAULT, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + thisClosedSink, + otherClosedSink.asMono().and(thisClosedSink.asMono())); + } + + public void setRequestAcceptor(RSocket requestAcceptor) { + this.requestAcceptor = requestAcceptor; + init(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java new file mode 100644 index 000000000..3112a0943 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ReconnectMonoTests.java @@ -0,0 +1,1108 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.awaitility.Awaitility.await; + +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.io.IOException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeoutException; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Exceptions; +import reactor.core.Scannable; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.function.Tuple2; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ReconnectMonoTests { + + private Queue retries = new ConcurrentLinkedQueue<>(); + private Queue> received = new ConcurrentLinkedQueue<>(); + private Queue expired = new ConcurrentLinkedQueue<>(); + + @Test + public void shouldExpireValueOnRacingDisposeAndNext() { + Hooks.onErrorDropped(t -> {}); + Hooks.onNextDropped(System.out::println); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final CoreSubscriber[] monoSubscribers = new CoreSubscriber[1]; + Subscription mockSubscription = Mockito.mock(Subscription.class); + final Mono stringMono = + new Mono() { + @Override + public void subscribe(CoreSubscriber actual) { + actual.onSubscribe(mockSubscription); + monoSubscribers[0] = actual; + } + }; + + final ReconnectMono reconnectMono = + stringMono + .doOnDiscard(Object.class, System.out::println) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + RaceTestUtils.race(() -> monoSubscribers[0].onNext("value" + index), reconnectMono::dispose); + + monoSubscribers[0].onComplete(); + + subscriber.assertTerminated(); + Mockito.verify(mockSubscription).cancel(); + + if (!subscriber.errors().isEmpty()) { + subscriber + .assertError(CancellationException.class) + .assertErrorMessage("ReconnectMono has already been disposed"); + + assertThat(expired).containsOnly("value" + i); + } else { + subscriber.assertValues("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, () -> reconnectMono.subscribe(raceSubscriber)); + + subscriber.assertTerminated(); + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, subscriber))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + + RaceTestUtils.race( + reconnectMono::invalidate, + () -> { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { + reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_not_expire" + index); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + } + }); + + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); + + raceSubscriber.assertComplete(); + String v = raceSubscriber.values().get(0); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + assertThat(v).isEqualTo("value_to_not_expire" + index); + } else { + assertThat(v).isEqualTo("value_to_expire" + index); + } + + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + assertThat(received) + .hasSize(2) + .containsExactly( + Tuples.of("value_to_expire" + i, reconnectMono), + Tuples.of("value_to_not_expire" + i, reconnectMono)); + } else { + assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + reconnectMono.resolvingInner.mainSubscriber.onNext("value_to_expire" + i); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + + RaceTestUtils.race( + reconnectMono::invalidate, + reconnectMono::invalidate, + () -> { + reconnectMono.subscribe(raceSubscriber); + if (!raceSubscriber.isTerminated()) { + reconnectMono.resolvingInner.mainSubscriber.onNext( + "value_to_possibly_expire" + index); + reconnectMono.resolvingInner.mainSubscriber.onComplete(); + } + }); + + subscriber.assertTerminated(); + subscriber.assertValues("value_to_expire" + i); + + raceSubscriber.assertComplete(); + assertThat(raceSubscriber.values().get(0)) + .isIn("value_to_possibly_expire" + index, "value_to_expire" + index); + + if (expired.size() == 2) { + assertThat(expired) + .hasSize(2) + .containsExactly("value_to_expire" + i, "value_to_possibly_expire" + i); + } else { + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + } + if (received.size() == 2) { + assertThat(received) + .hasSize(2) + .containsExactly( + Tuples.of("value_to_expire" + i, reconnectMono), + Tuples.of("value_to_possibly_expire" + i, reconnectMono)); + } else { + assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + final Mono source = + Mono.fromSupplier( + new Supplier() { + boolean once = false; + + @Override + public String get() { + + if (!once) { + once = true; + return "value_to_expire" + index; + } + + return "value_to_not_expire" + index; + } + }); + + final ReconnectMono reconnectMono = + new ReconnectMono<>( + source.subscribeOn(Schedulers.boundedElastic()), onExpire(), onValue()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + subscriber.await().assertComplete(); + + assertThat(expired).isEmpty(); + + RaceTestUtils.race( + () -> + assertThat(reconnectMono.block()) + .matches( + (v) -> + v.equals("value_to_not_expire" + index) + || v.equals("value_to_expire" + index)), + reconnectMono::invalidate); + + subscriber.assertTerminated(); + + subscriber.assertValues("value_to_expire" + i); + + assertThat(expired).hasSize(1).containsOnly("value_to_expire" + i); + if (reconnectMono.resolvingInner.subscribers == ResolvingOperator.READY) { + await().atMost(Duration.ofSeconds(5)).until(() -> received.size() == 2); + assertThat(received) + .hasSize(2) + .containsExactly( + Tuples.of("value_to_expire" + i, reconnectMono), + Tuples.of("value_to_not_expire" + i, reconnectMono)); + } else { + assertThat(received) + .hasSize(1) + .containsOnly(Tuples.of("value_to_expire" + i, reconnectMono)); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = new AssertSubscriber<>(); + final AssertSubscriber raceSubscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + assertThat(cold.subscribeCount()).isZero(); + + RaceTestUtils.race( + () -> reconnectMono.subscribe(subscriber), () -> reconnectMono.subscribe(raceSubscriber)); + + subscriber.assertTerminated(); + assertThat(raceSubscriber.isTerminated()).isTrue(); + + subscriber.assertValues("value" + i); + raceSubscriber.assertValues("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat(cold.subscribeCount()).isOne(); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, subscriber))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = new AssertSubscriber<>(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + assertThat(cold.subscribeCount()).isZero(); + + String[] values = new String[1]; + + RaceTestUtils.race( + () -> values[0] = reconnectMono.block(timeout), + () -> reconnectMono.subscribe(subscriber)); + + subscriber.assertTerminated(); + + subscriber.assertValues("value" + i); + assertThat(values).containsExactly("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat(cold.subscribeCount()).isOne(); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, subscriber))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value" + i); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + assertThat(cold.subscribeCount()).isZero(); + + String[] values1 = new String[1]; + String[] values2 = new String[1]; + + RaceTestUtils.race( + () -> values1[0] = reconnectMono.block(timeout), + () -> values2[0] = reconnectMono.block(timeout)); + + assertThat(values2).containsExactly("value" + i); + assertThat(values1).containsExactly("value" + i); + + assertThat(reconnectMono.resolvingInner.subscribers).isEqualTo(ResolvingOperator.READY); + + assertThat(cold.subscribeCount()).isOne(); + + assertThat( + reconnectMono.resolvingInner.add( + new ResolvingOperator.MonoDeferredResolutionOperator<>( + reconnectMono.resolvingInner, new AssertSubscriber<>()))) + .isEqualTo(ResolvingOperator.READY_STATE); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndNoValueComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + subscriber.assertTerminated(); + + Throwable error = subscriber.errors().get(0); + + if (error instanceof CancellationException) { + assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(error) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Source completed empty"); + } + + assertThat(expired).isEmpty(); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndComplete() { + Hooks.onErrorDropped(t -> {}); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(cold::complete, reconnectMono::dispose); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + assertThat(subscriber.errors().get(0)) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); + } + + assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndError() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(error).isInstanceOf(RuntimeException.class).hasMessage("test"); + } + } else { + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); + } + + assertThat(expired).hasSize(1).containsOnly("value" + i); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndErrorWithNoBackoff() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + cold.mono() + .retryWhen(Retry.max(1).filter(t -> t instanceof Exception)) + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + + cold.next("value" + i); + + RaceTestUtils.race(() -> cold.error(runtimeException), reconnectMono::dispose); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Throwable error = subscriber.errors().get(0); + if (error instanceof CancellationException) { + assertThat(error) + .isInstanceOf(CancellationException.class) + .hasMessage("ReconnectMono has already been disposed"); + } else { + assertThat(error).matches(Exceptions::isRetryExhausted).hasCause(runtimeException); + } + + assertThat(expired).hasSize(1).containsOnly("value" + i); + } else { + assertThat(received).hasSize(1).containsOnly(Tuples.of("value" + i, reconnectMono)); + subscriber.assertValues("value" + i); + } + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldThrowOnBlocking() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on Mono blocking read"); + } + + @Test + public void shouldThrowOnBlockingIfHasAlreadyTerminated() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + publisher.error(new RuntimeException("test")); + + Assertions.assertThatThrownBy(() -> reconnectMono.block(Duration.ofMillis(100))) + .isInstanceOf(RuntimeException.class) + .hasMessage("test") + .hasSuppressedException(new Exception("Terminated with an error")); + } + + @Test + public void shouldBeScannable() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final Mono parent = publisher.mono(); + final ReconnectMono reconnectMono = + parent.as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final Scannable scannableOfReconnect = Scannable.from(reconnectMono); + + assertThat( + (List) + scannableOfReconnect.parents().map(s -> s.getClass()).collect(Collectors.toList())) + .hasSize(1) + .containsExactly(publisher.mono().getClass()); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(false); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)).isNull(); + + final AssertSubscriber subscriber = + reconnectMono.subscribeWith(new AssertSubscriber<>()); + + final Scannable scannableOfMonoProcessor = Scannable.from(subscriber); + + assertThat( + (List) + scannableOfMonoProcessor + .parents() + .map(s -> s.getClass()) + .collect(Collectors.toList())) + .hasSize(4) + .containsExactly( + ResolvingOperator.MonoDeferredResolutionOperator.class, + ReconnectMono.ResolvingInner.class, + ReconnectMono.class, + publisher.mono().getClass()); + + reconnectMono.dispose(); + + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.TERMINATED)).isEqualTo(true); + assertThat(scannableOfReconnect.scanUnsafe(Scannable.Attr.ERROR)) + .isInstanceOf(CancellationException.class); + } + + @Test + public void shouldNotExpiredIfNotCompleted() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + AssertSubscriber subscriber = new AssertSubscriber<>(); + + reconnectMono.subscribe(subscriber); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.next("test"); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + reconnectMono.invalidate(); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + publisher.assertSubscribers(1); + assertThat(publisher.subscribeCount()).isEqualTo(1); + + publisher.complete(); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); + + publisher.assertSubscribers(0); + assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldNotEmitUntilCompletion() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + AssertSubscriber subscriber = new AssertSubscriber<>(); + + reconnectMono.subscribe(subscriber); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.next("test"); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.complete(); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + subscriber.assertTerminated(); + subscriber.assertValues("test"); + } + + @Test + public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest() { + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.REQUEST_OVERFLOW); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + AssertSubscriber subscriber = new AssertSubscriber<>(); + + reconnectMono.subscribe(subscriber); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + publisher.next("test"); + + assertThat(expired).isEmpty(); + assertThat(received).isEmpty(); + assertThat(subscriber.isTerminated()).isFalse(); + + subscriber.cancel(); + + assertThat(reconnectMono.resolvingInner.subscribers) + .isEqualTo(ResolvingOperator.EMPTY_SUBSCRIBED); + + publisher.complete(); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + assertThat(subscriber.values()).isEmpty(); + } + + @Test + public void shouldExpireValueOnDispose() { + final TestPublisher publisher = TestPublisher.create(); + // given + final int timeout = 10; + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono) + .expectSubscription() + .then(() -> publisher.next("value")) + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1); + + reconnectMono.dispose(); + + assertThat(expired).hasSize(1); + assertThat(received).hasSize(1); + assertThat(reconnectMono.isDisposed()).isTrue(); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectError(CancellationException.class) + .verify(Duration.ofSeconds(timeout)); + } + + @Test + public void shouldNotifyAllTheSubscribers() { + final TestPublisher publisher = TestPublisher.create(); + + final ReconnectMono reconnectMono = + publisher.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + final AssertSubscriber sub1 = new AssertSubscriber<>(); + final AssertSubscriber sub2 = new AssertSubscriber<>(); + final AssertSubscriber sub3 = new AssertSubscriber<>(); + final AssertSubscriber sub4 = new AssertSubscriber<>(); + + reconnectMono.subscribe(sub1); + reconnectMono.subscribe(sub2); + reconnectMono.subscribe(sub3); + reconnectMono.subscribe(sub4); + + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(4); + + final ArrayList> subscribers = new ArrayList<>(200); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final AssertSubscriber subA = new AssertSubscriber<>(); + final AssertSubscriber subB = new AssertSubscriber<>(); + subscribers.add(subA); + subscribers.add(subB); + RaceTestUtils.race(() -> reconnectMono.subscribe(subA), () -> reconnectMono.subscribe(subB)); + } + + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 4); + + sub1.cancel(); + + assertThat(reconnectMono.resolvingInner.subscribers).hasSize(RaceTestConstants.REPEATS * 2 + 3); + + publisher.next("value"); + + assertThat(sub1.scan(Scannable.Attr.CANCELLED)).isTrue(); + assertThat(sub2.values().get(0)).isEqualTo("value"); + assertThat(sub3.values().get(0)).isEqualTo("value"); + assertThat(sub4.values().get(0)).isEqualTo("value"); + + for (AssertSubscriber sub : subscribers) { + assertThat(sub.values().get(0)).isEqualTo("value"); + assertThat(sub.isTerminated()).isTrue(); + } + + assertThat(publisher.subscribeCount()).isEqualTo(1); + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value"); + cold.complete(); + final int timeout = 10; + + final ReconnectMono reconnectMono = + cold.flux() + .takeLast(1) + .next() + .as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::invalidate); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + cold.next("value2"); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectNext("value2") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received) + .hasSize(2) + .containsOnly(Tuples.of("value", reconnectMono), Tuples.of("value2", reconnectMono)); + + assertThat(cold.subscribeCount()).isEqualTo(2); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestPublisher cold = TestPublisher.createCold(); + cold.next("value"); + final int timeout = 10000; + + final ReconnectMono reconnectMono = + cold.mono().as(source -> new ReconnectMono<>(source, onExpire(), onValue())); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectNext("value") + .expectComplete() + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).isEmpty(); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + RaceTestUtils.race(reconnectMono::invalidate, reconnectMono::dispose); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectError(CancellationException.class) + .verify(Duration.ofSeconds(timeout)); + + assertThat(expired).hasSize(1).containsOnly("value"); + assertThat(received).hasSize(1).containsOnly(Tuples.of("value", reconnectMono)); + + assertThat(cold.subscribeCount()).isEqualTo(1); + + expired.clear(); + received.clear(); + } + } + + @Test + public void shouldTimeoutRetryWithVirtualTime() { + // given + final int minBackoff = 1; + final int maxBackoff = 5; + final int timeout = 10; + + // then + StepVerifier.withVirtualTime( + () -> + Mono.error(new RuntimeException("Something went wrong")) + .retryWhen( + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(minBackoff)) + .doAfterRetry(onRetry()) + .maxBackoff(Duration.ofSeconds(maxBackoff))) + .timeout(Duration.ofSeconds(timeout)) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())) + .subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .thenAwait(Duration.ofSeconds(timeout)) + .expectError(TimeoutException.class) + .verify(Duration.ofSeconds(timeout)); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + @Test + public void ensuresThatMainSubscriberAllowsOnlyTerminationWithValue() { + final int timeout = 10; + final ReconnectMono reconnectMono = + new ReconnectMono<>(Mono.empty(), onExpire(), onValue()); + + StepVerifier.create(reconnectMono.subscribeOn(Schedulers.boundedElastic())) + .expectSubscription() + .expectErrorSatisfies( + t -> + assertThat(t) + .hasMessage("Source completed empty") + .isInstanceOf(IllegalStateException.class)) + .verify(Duration.ofSeconds(timeout)); + } + + @Test + public void monoRetryNoBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.max(2).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.create(mono).verifyErrorMatches(Exceptions::isRetryExhausted); + assertRetries(IOException.class, IOException.class); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryFixedBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen(Retry.fixedDelay(1, Duration.ofMillis(500)).doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .expectNoEvent(Duration.ofMillis(300)) + .thenAwait(Duration.ofMillis(300)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + @Test + public void monoRetryExponentialBackoff() { + Mono mono = + Mono.error(new IOException()) + .retryWhen( + Retry.backoff(4, Duration.ofMillis(100)) + .maxBackoff(Duration.ofMillis(500)) + .jitter(0.0d) + .doAfterRetry(onRetry())) + .as(m -> new ReconnectMono<>(m, onExpire(), onValue())); + + StepVerifier.withVirtualTime(() -> mono) + .expectSubscription() + .thenAwait(Duration.ofMillis(100)) + .thenAwait(Duration.ofMillis(200)) + .thenAwait(Duration.ofMillis(400)) + .thenAwait(Duration.ofMillis(500)) + .verifyErrorMatches(Exceptions::isRetryExhausted); + + assertRetries(IOException.class, IOException.class, IOException.class, IOException.class); + + assertThat(received).isEmpty(); + assertThat(expired).isEmpty(); + } + + Consumer onRetry() { + return context -> retries.add(context); + } + + BiConsumer onValue() { + return (v, __) -> received.add(Tuples.of(v, __)); + } + + Consumer onExpire() { + return (v) -> expired.add(v); + } + + @SafeVarargs + private final void assertRetries(Class... exceptions) { + assertThat(retries.size()).isEqualTo(exceptions.length); + int index = 0; + for (Iterator it = retries.iterator(); it.hasNext(); ) { + Retry.RetrySignal retryContext = it.next(); + assertThat(retryContext.totalRetries()).isEqualTo(index); + assertThat(retryContext.failure().getClass()).isEqualTo(exceptions[index]); + index++; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java new file mode 100644 index 000000000..c1e0a6876 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelRequesterFluxTest.java @@ -0,0 +1,845 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.CANCEL; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(10); + + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + stateAssert.hasSubscribedFlag().hasRequestN(10).hasNoFirstFrameSentFlag(); + + publisher.assertMaxRequested(1).next(payload); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(10).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(10) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(11).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelRequesterFlux.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelRequesterFlux.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldErrorWithoutInitializingRemoteStreamIfSourceIsEmpty(boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.complete(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Empty Source"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + public void streamShouldPropagateErrorWithoutInitializingRemoteStreamIfTheFirstSignalIsError( + boolean doRequest) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + if (doRequest) { + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + } + + publisher.error(new RuntimeException("test")); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + assertSubscriber + .assertTerminated() + .assertError(RuntimeException.class) + .assertErrorMessage("test"); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void streamShouldBeInHalfClosedStateOnTheInboundCancellation(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + requestChannelRequesterFlux.handleCancel(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasOutboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleComplete(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleComplete(); + + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasInboundTerminated(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + requestChannelRequesterFlux.handleCancel(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound"}) + public void errorShouldTerminateExecution(String terminationMode) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload payload3 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + requestChannelRequesterFlux.handleRequestN(10); + publisher.assertMaxRequested(10); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + publisher.assertMaxRequested(Long.MAX_VALUE); + + publisher.next(payload2.retain(), payload3.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload2) + .hasNoLeaks(); + payload2.release(); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.NEXT) + .hasPayload(payload3) + .hasNoLeaks(); + payload3.release(); + + if (terminationMode.equals("outbound")) { + publisher.error(new ApplicationErrorException("test")); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.ERROR) + .hasData("test") + .hasNoLeaks(); + } else if (terminationMode.equals("inbound")) { + requestChannelRequesterFlux.handleError(new ApplicationErrorException("test")); + publisher.assertWasCancelled(); + } + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestChannelRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + stateAssert.hasSubscribedFlag().hasRequestN(1).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload payload1 = TestRequesterResponderSupport.randomPayload(allocator); + + publisher.next(payload1.retain()); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasPayload(payload1) + .hasRequestN(1) + .hasNoLeaks(); + payload1.release(); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + + publisher.assertMaxRequested(1); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + publisher.assertWasCancelled(); + + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelRequesterFlux.handlePayload(responsePayload1); + requestChannelRequesterFlux.handlePayload(responsePayload2); + requestChannelRequesterFlux.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelRequesterFlux.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelRequesterFlux.handleComplete(); + } else { + requestChannelRequesterFlux.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.pollFrame(); + if (errorFrameOrEmpty != null) { + if (outboundTerminationMode.equals("onError")) { + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData("outboundException") + .hasNoLeaks(); + } else { + FrameAssert.assertThat(errorFrameOrEmpty).typeOf(FrameType.CANCEL).hasNoLeaks(); + } + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + PayloadAssert.assertThat(signal.get()) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + Assertions.assertThat(throwable).isEqualTo(inboundException); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + Assertions.assertThat(throwable) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else if (inboundTerminationMode.equals("complete")) { + if (signal.isOnComplete()) { + Assertions.assertThat(droppedErrors.get(0)) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).isEmpty(); + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + } else { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the error signal[%s] is the last signal, but the last was %s", + j, values.size() - 1) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"complete", "cancel"}) + public void shouldRemoveItselfFromActiveStreamsWhenInboundAndOutboundAreTerminated( + String outboundTerminationMode) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(TestPublisher.Violation.DEFER_CANCELLATION); + + final RequestChannelRequesterFlux requestChannelRequesterFlux = + new RequestChannelRequesterFlux(publisher, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelRequesterFlux); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber> assertSubscriber = + requestChannelRequesterFlux.materialize().subscribeWith(AssertSubscriber.create(0)); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Integer.MAX_VALUE); + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasNoFirstFrameSentFlag(); + activeStreams.assertNoActiveStreams(); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + publisher.next(requestPayload); + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + activeStreams.assertHasStream(1, requestChannelRequesterFlux); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_CHANNEL) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelRequesterFlux.handleRequestN(Long.MAX_VALUE); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("cancel")) { + requestChannelRequesterFlux.handleCancel(); + } else { + publisher.complete(); + } + }, + requestChannelRequesterFlux::handleComplete); + + ByteBuf completeFrameOrNull = sender.pollFrame(); + if (completeFrameOrNull != null) { + FrameAssert.assertThat(completeFrameOrNull) + .hasStreamId(1) + .typeOf(FrameType.COMPLETE) + .hasNoLeaks(); + } + + assertSubscriber.assertTerminated().assertComplete(); + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java new file mode 100644 index 000000000..890458caf --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestChannelResponderSubscriberTest.java @@ -0,0 +1,890 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameType.*; +import static reactor.test.publisher.TestPublisher.Violation.CLEANUP_ON_TERMINATE; +import static reactor.test.publisher.TestPublisher.Violation.DEFER_CANCELLATION; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Exceptions; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Signal; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; + +public class RequestChannelResponderSubscriberTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + @ParameterizedTest + @ValueSource(strings = {"inbound", "outbound", "inboundCancel"}) + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately(String completionCase) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + publisher.next(TestRequesterResponderSupport.genericPayload(allocator)); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + + ByteBuf firstFragment = fragments.remove(0); + requestChannelResponderSubscriber.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollows = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestChannelResponderSubscriber.handleNext(followingFragment, hasFollows, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + if (completionCase.equals("inbound")) { + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + } else if (completionCase.equals("inboundCancel")) { + assertSubscriber.cancel(); + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }); + + FrameAssert.assertThat(sender.awaitFrame()).typeOf(CANCEL).hasStreamId(1).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasInboundTerminated(); + + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + } else if (completionCase.equals("outbound")) { + publisher.complete(); + FrameAssert.assertThat(sender.awaitFrame()).typeOf(FrameType.COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag() + .hasOutboundTerminated(); + + requestChannelResponderSubscriber.handleComplete(); + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(nextPayload).hasNoLeaks(), + p -> { + PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(); + randomPayload.release(); + }) + .assertComplete(); + } + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + assertSubscriber.request(1); + + // state machine check + stateAssert.hasSubscribedFlag().hasFirstFrameSentFlag().hasRequestN(1); + + // should not send requestN since 1 is remaining + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + // should not send requestN since 1 is remaining + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + Payload nextPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(nextPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isSameAs(nextPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(nextPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + public void failOnOverflowBeforeFirstPayloadIsSent() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + publisher.assertMaxRequested(1); + // state machine check + stateAssert.isUnsubscribed().hasRequestN(0); + + final AssertSubscriber assertSubscriber = + requestChannelResponderSubscriber.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(firstPayload.refCnt()).isOne(); + + // state machine check + stateAssert.hasSubscribedFlagOnly().hasRequestN(0); + + Payload unrequestedPayload = TestRequesterResponderSupport.genericPayload(allocator); + requestChannelResponderSubscriber.handlePayload(unrequestedPayload); + + final ByteBuf cancelErrorFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelErrorFrame) + .isNotNull() + .typeOf(ERROR) + .hasData("The number of messages received exceeds the number requested") + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber.request(1); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertErrorMessage("The number of messages received exceeds the number requested"); + + Assertions.assertThat(firstPayload.refCnt()).isZero(); + Assertions.assertThat(unrequestedPayload.refCnt()).isZero(); + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCompleteWithSubscription() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> + requestChannelResponderSubscriber + .doOnNext(__ -> assertSubscriber.request(1)) + .subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleComplete()); + + stateAssert + .hasSubscribedFlag() + .hasInboundTerminated() + .hasFirstFrameSentFlag() + .hasRequestNBetween(1, 2); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isSameAs(firstPayload).hasNoLeaks()) + .assertTerminated() + .assertComplete(); + + publisher.complete(); + + if (sender.getSent().size() > 1) { + FrameAssert.assertThat(sender.awaitFrame()) + .hasStreamId(1) + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasNoLeaks(); + } + FrameAssert.assertThat(sender.awaitFrame()).hasStreamId(1).typeOf(COMPLETE).hasNoLeaks(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleErrorWithSubscription() { + ApplicationErrorException applicationErrorException = new ApplicationErrorException("test"); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleError(applicationErrorException)); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(applicationErrorException.getClass()) + .assertErrorMessage("test"); + + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingOutboundErrorWithSubscription() { + RuntimeException exception = new RuntimeException("test"); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> publisher.error(exception)); + + stateAssert.isTerminated(); + + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(ERROR) + .hasData("test") + .hasStreamId(1) + .hasNoLeaks(); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Outbound has terminated with an error"); + + allocator.assertHasNoLeaks(); + } + } + + @Test + public void streamShouldWorkCorrectlyWhenRacingHandleCancelWithSubscription() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + final TestPublisher publisher = TestPublisher.create(); + + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, firstPayload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestChannelResponderSubscriber); + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertHasStream(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(1); + + RaceTestUtils.race( + () -> requestChannelResponderSubscriber.subscribe(assertSubscriber), + () -> requestChannelResponderSubscriber.handleCancel()); + + stateAssert.isTerminated(); + + publisher.assertCancelled(1); + + if (!assertSubscriber.values().isEmpty()) { + assertSubscriber.assertValuesWith( + p -> PayloadAssert.assertThat(p).isSameAs(p).hasNoLeaks()); + } + + assertSubscriber + .assertTerminated() + .assertError(CancellationException.class) + .assertErrorMessage("Inbound has been canceled"); + + allocator.assertHasNoLeaks(); + } + } + + static Stream cases() { + return Stream.of( + Arguments.arguments("complete", "sizeError"), + Arguments.arguments("complete", "refCntError"), + Arguments.arguments("complete", "onError"), + Arguments.arguments("error", "sizeError"), + Arguments.arguments("error", "refCntError"), + Arguments.arguments("error", "onError"), + Arguments.arguments("cancel", "sizeError"), + Arguments.arguments("cancel", "refCntError"), + Arguments.arguments("cancel", "onError")); + } + + @ParameterizedTest + @MethodSource("cases") + public void shouldHaveEventsDeliveredSeriallyWhenOutboundErrorRacingWithInboundSignals( + String inboundTerminationMode, String outboundTerminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final ApplicationErrorException inboundException = + new ApplicationErrorException("inboundException"); + final ArrayList droppedErrors = new ArrayList<>(); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + Hooks.onErrorDropped(droppedErrors::add); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + + Payload requestPayload = TestRequesterResponderSupport.randomPayload(allocator); + final RequestChannelResponderSubscriber requestChannelResponderSubscriber = + new RequestChannelResponderSubscriber(1, 1, requestPayload, activeStreams); + + activeStreams.activeStreams.put(1, requestChannelResponderSubscriber); + + publisher.subscribe(requestChannelResponderSubscriber); + final AssertSubscriber> assertSubscriber = + requestChannelResponderSubscriber + .materialize() + .subscribeWith(AssertSubscriber.create(0)); + + assertSubscriber.request(Integer.MAX_VALUE); + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasNoLeaks(); + + requestChannelResponderSubscriber.handleRequestN(Long.MAX_VALUE); + + Payload responsePayload1 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload2 = TestRequesterResponderSupport.randomPayload(allocator); + Payload responsePayload3 = TestRequesterResponderSupport.randomPayload(allocator); + + Payload releasedPayload = ByteBufPayload.create(Unpooled.EMPTY_BUFFER); + releasedPayload.release(); + + RaceTestUtils.race( + () -> { + if (outboundTerminationMode.equals("onError")) { + publisher.error(outboundException); + } else if (outboundTerminationMode.equals("refCntError")) { + publisher.next(releasedPayload); + } else { + publisher.next(oversizePayload); + } + }, + () -> { + requestChannelResponderSubscriber.handlePayload(responsePayload1); + requestChannelResponderSubscriber.handlePayload(responsePayload2); + requestChannelResponderSubscriber.handlePayload(responsePayload3); + + if (inboundTerminationMode.equals("error")) { + requestChannelResponderSubscriber.handleError(inboundException); + } else if (inboundTerminationMode.equals("complete")) { + requestChannelResponderSubscriber.handleComplete(); + } else { + requestChannelResponderSubscriber.handleCancel(); + } + }); + + ByteBuf errorFrameOrEmpty = sender.pollFrame(); + if (errorFrameOrEmpty != null) { + String message; + if (outboundTerminationMode.equals("onError")) { + message = outboundException.getMessage(); + } else if (outboundTerminationMode.equals("sizeError")) { + message = String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK); + } else { + message = "Failed to validate payload. Cause:refCnt: 0"; + } + FrameAssert.assertThat(errorFrameOrEmpty) + .typeOf(FrameType.ERROR) + .hasData(message) + .hasNoLeaks(); + } + + List> values = assertSubscriber.values(); + for (int j = 0; j < values.size(); j++) { + Signal signal = values.get(j); + + if (signal.isOnNext()) { + Payload payload = signal.get(); + if (j == 0) { + Assertions.assertThat(payload).isEqualTo(requestPayload); + } + + PayloadAssert.assertThat(payload) + .describedAs("Expected that the next signal[%s] to have no leaks", j) + .hasNoLeaks(); + } else { + if (inboundTerminationMode.equals("error")) { + Assertions.assertThat(signal.isOnError()).isTrue(); + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + inboundException.getMessage(), + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + if (throwable == inboundException) { + Assertions.assertThat(droppedErrors) + .hasSize(1) + .first() + .isExactlyInstanceOf( + outboundTerminationMode.equals("onError") + ? outboundException.getClass() + : outboundTerminationMode.equals("refCntError") + ? IllegalReferenceCountException.class + : IllegalArgumentException.class); + } else { + Assertions.assertThat(droppedErrors).containsOnly(inboundException); + } + } + } else if (inboundTerminationMode.equals("complete")) { + Assertions.assertThat(droppedErrors).isEmpty(); + if (signal.isOnError()) { + Assertions.assertThat(signal.getThrowable()) + .isExactlyInstanceOf(CancellationException.class) + .matches( + t -> + t.getMessage().equals("Inbound has been canceled") + || t.getMessage().equals("Outbound has terminated with an error")); + } + } else { + Throwable throwable = signal.getThrowable(); + if (Exceptions.isMultiple(throwable)) { + Assertions.assertThat( + Arrays.stream(throwable.getSuppressed()).map(Throwable::getMessage)) + .containsExactlyInAnyOrder( + "Inbound has been canceled", + outboundTerminationMode.equals("onError") + ? "Outbound has terminated with an error" + : "Inbound has been canceled"); + } else { + Assertions.assertThat(throwable).isExactlyInstanceOf(CancellationException.class); + } + } + + Assertions.assertThat(j) + .describedAs( + "Expected that the %s signal[%s] is the last signal, but the last was %s", + signal, j, values.get(values.size() - 1)) + .isEqualTo(values.size() - 1); + } + } + + allocator.assertHasNoLeaks(); + droppedErrors.clear(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest + @ValueSource(strings = {"onError", "sizeError", "refCntError", "cancel"}) + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(String terminationMode) { + final RuntimeException outboundException = new RuntimeException("outboundException"); + final Payload oversizePayload = + DefaultPayload.create(new byte[FRAME_LENGTH_MASK], new byte[FRAME_LENGTH_MASK]); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final TestPublisher publisher = + TestPublisher.createNoncompliant(DEFER_CANCELLATION, CLEANUP_ON_TERMINATE); + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(2); + + Payload firstPayload = TestRequesterResponderSupport.genericPayload(allocator); + final RequestChannelResponderSubscriber requestOperator = + new RequestChannelResponderSubscriber(1, Long.MAX_VALUE, firstPayload, activeStreams); + + publisher.subscribe(requestOperator); + requestOperator.subscribe(assertSubscriber); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, responsePayload); + + Payload releasedPayload1 = ByteBufPayload.create(new byte[0]); + Payload releasedPayload2 = ByteBufPayload.create(new byte[0]); + releasedPayload1.release(); + releasedPayload2.release(); + + RaceTestUtils.race( + () -> { + switch (terminationMode) { + case "onError": + publisher.error(outboundException); + break; + case "sizeError": + publisher.next(oversizePayload); + break; + case "refCntError": + publisher.next(releasedPayload1); + break; + case "cancel": + default: + assertSubscriber.cancel(); + } + }, + () -> { + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + requestOperator.handleNext(frame, lastFragmentId != j, false); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + + PayloadAssert.assertThat(values.get(0)).isEqualTo(firstPayload).hasNoLeaks(); + + if (values.size() > 1) { + Payload payload = values.get(1); + PayloadAssert.assertThat(payload).isEqualTo(responsePayload).hasNoLeaks(); + } + + if (!sender.isEmpty()) { + if (terminationMode.equals("cancel")) { + assertSubscriber.assertNotTerminated(); + } else { + assertSubscriber.assertTerminated().assertError(); + } + + final ByteBuf requstFrame = sender.awaitFrame(); + FrameAssert.assertThat(requstFrame) + .isNotNull() + .typeOf(REQUEST_N) + .hasRequestN(1) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf terminalFrame = sender.awaitFrame(); + FrameAssert.assertThat(terminalFrame) + .isNotNull() + .typeOf(terminationMode.equals("cancel") ? CANCEL : ERROR) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + + PayloadAssert.assertThat(responsePayload).hasNoLeaks(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java new file mode 100644 index 000000000..b39ac62d9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestResponseRequesterMonoTest.java @@ -0,0 +1,698 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.core.TestRequesterResponderSupport.genericPayload; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestResponseRequesterMonoTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + * + */ + + /** + * General StateMachine transition test. No Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends frame on the first request 3) + * terminates up on receiving the first signal (terminates on first next | error | next over + * reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(RequestResponseRequesterMono.STATE, requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + PayloadAssert.assertThat(payload).isReleased(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + stateAssert.isTerminated(); + + if (!sender.isEmpty()) { + ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnSubscriptionResponses() { + return Stream.of( + // next case + (rrm, sv) -> + sv.then(() -> rrm.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .expectComplete(), + // complete case + (rrm, sv) -> sv.then(rrm::handleComplete).expectComplete(), + // error case + (rrm, sv) -> + sv.then(() -> rrm.handleError(new ApplicationErrorException("test"))) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + // fragmentation case + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()); + rrm.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then(stateAssert::isTerminated) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + p.release(); + }) + .then(payload::release) + .expectComplete(); + }, + (rrm, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + StateAssert stateAssert = StateAssert.assertThat(rrm); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rrm.allocator, + 64, + FrameType.REQUEST_RESPONSE, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rrm.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then( + () -> { + rrm.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rrm.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + stateAssert + .hasSubscribedFlag() + .hasRequestN(1) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * General StateMachine transition test. Fragmentation enabled In this test we check that the + * given instance of RequestResponseMono: 1) subscribes 2) sends fragments frames on the first + * request 3) terminates up on receiving the first signal (terminates on first next | error | next + * over reassembly | complete) + */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnSubscriptionResponses") + public void frameFragmentsShouldBeSentOnSubscription( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestResponseRequesterMono, + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(stateAssert::hasSubscribedFlagOnly) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(activeStreams::assertNoActiveStreams) + .thenRequest(1) + .then(() -> stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestResponseRequesterMono))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOf(metadata, 52)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_RESPONSE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET_WITH_METADATA) // 64 - 6 (frame headers) - 3 (encoded metadata + // length) - 3 frame length + .hasMetadata(Arrays.copyOfRange(metadata, 52, 65)) + .hasData(Arrays.copyOf(data, 39)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize( + 64 - FRAME_OFFSET) // 64 - 6 (frame headers) - 3 frame length (no metadata - no length) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 39, 94)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(35) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 94, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General StateMachine transition test. Ensures that no fragment is sent if mono was cancelled + * before any requests + */ + @Test + public void shouldBeNoOpsOnCancel() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = ByteBufPayload.create("testData", "testMetadata"); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(() -> stateAssert.hasSubscribedFlagOnly()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenCancel() + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload is an invalid one. + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestResponseRequesterMono); + + stateAssert.isTerminated(); + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is disabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final Payload payload = ByteBufPayload.create(""); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrate to + * the terminated in case the given payload was release in the middle of interaction. + * Fragmentation is enabled + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestResponseRequesterMono, 0) + .expectSubscription() + .then(payload::release) + .thenRequest(1) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * General state machine test Ensures that a Subscriber receives error signal and state migrates + * to the terminated in case the given payload is too big with disabled fragmentation + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + ; + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that error check happens exactly before frame sent. This cases ensures that in case no + * lease / other external errors appeared, the local subscriber received the same one. No frames + * should be sent + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestResponseRequesterMono); + + activeStreams.assertNoActiveStreams(); + stateAssert.isUnsubscribed(); + + monoConsumer.accept(requestResponseRequesterMono); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then(() -> StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestResponseRequesterMono -> + Assertions.assertThatThrownBy(requestResponseRequesterMono::block) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = genericPayload(allocator); + + final RequestResponseRequesterMono requestResponseRequesterMono = + new RequestResponseRequesterMono(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestResponseRequesterMono).name()) + .isEqualTo("source(RequestResponseMono)"); + requestResponseRequesterMono.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java new file mode 100644 index 000000000..8702d1a80 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequestStreamRequesterFluxTest.java @@ -0,0 +1,1227 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA; +import static io.rsocket.core.FragmentationUtils.FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N; +import static io.rsocket.core.PayloadValidationUtils.INVALID_PAYLOAD_ERROR_MESSAGE; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.test.StepVerifier; + +public class RequestStreamRequesterFluxTest { + + @BeforeAll + public static void setUp() { + StepVerifier.setDefaultTimeout(Duration.ofSeconds(2)); + } + + /* + * +-------------------------------+ + * | General Test Cases | + * +-------------------------------+ + */ + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * REQUESTED(0) -> REQUESTED(1) -> REQUESTED(0)
+   * REQUESTED(0) -> REQUESTED(MAX)
+   * REQUESTED(MAX) -> REQUESTED(MAX) && REASSEMBLY (extra flag enabled which indicates
+   * reassembly)
+   * REQUESTED(MAX) && REASSEMBLY -> TERMINATED
+   * 
+ */ + @Test + public void requestNFrameShouldBeSentOnSubscriptionAndThenSeparately() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check. Request N Frame should sent so request field should be 0 + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(2).hasFirstFrameSentFlag(); + + assertSubscriber.request(Long.MAX_VALUE); + final ByteBuf requestMaxNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestMaxNFrame) + .isNotNull() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_N) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + assertSubscriber.request(6); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, randomPayload); + ByteBuf firstFragment = fragments.remove(0); + requestStreamRequesterFlux.handleNext(firstFragment, true, false); + firstFragment.release(); + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag(); + + for (int i = 0; i < fragments.size(); i++) { + boolean hasFollowing = i != fragments.size() - 1; + ByteBuf followingFragment = fragments.get(i); + + requestStreamRequesterFlux.handleNext(followingFragment, hasFollowing, false); + followingFragment.release(); + } + + // state machine check + stateAssert + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag(); + + Payload finalRandomPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(finalRandomPayload); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber + .assertValuesWith( + p -> PayloadAssert.assertThat(p).isEqualTo(randomPayload).hasNoLeaks(), + p -> PayloadAssert.assertThat(p).isEqualTo(finalRandomPayload).hasNoLeaks()) + .assertComplete(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(MAX)
+   * REQUESTED(MAX) -> TERMINATED
+   * 
+ */ + @Test + public void requestNFrameShouldBeSentExactlyOnceIfItIsMaxAllowed() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(Long.MAX_VALUE / 2 + 1); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(Integer.MAX_VALUE) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + assertSubscriber.request(1); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + + stateAssert.hasSubscribedFlag().hasRequestN(Integer.MAX_VALUE).hasFirstFrameSentFlag(); + + requestStreamRequesterFlux.handlePayload(EmptyPayload.INSTANCE); + requestStreamRequesterFlux.handleComplete(); + + assertSubscriber.assertValues(EmptyPayload.INSTANCE).assertComplete(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + /** + * State Machine check. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * 
+ * + * And then for the following cases: + * + *
+   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
+   * affect state anyhow and should not sent any extra frames)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [3]: REQUESTED(0) -> REASSEMBLY
+   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [4]: REQUESTED(0) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
+   * 
+ */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag()) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + Assertions.assertThat(payload.refCnt()).isZero(); + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + // state machine check + stateAssert.isTerminated(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + allocator.assertHasNoLeaks(); + } + + static Stream, StepVerifier>> + frameShouldBeSentOnFirstRequestResponses() { + return Stream.of( + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .thenRequest(1L) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(), + (rsf, sv) -> + sv.then(() -> rsf.handlePayload(EmptyPayload.INSTANCE)) + .expectNext(EmptyPayload.INSTANCE) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag()) + .then(() -> rsf.handleError(new ApplicationErrorException("test"))) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .thenRequest(1L) + .thenRequest(1L) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(ApplicationErrorException.class)), + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + final Payload payload2 = ByteBufPayload.create(data, metadata); + + return sv.then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, true, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(1) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + final ByteBuf followingFrame = + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()); + rsf.handleNext(followingFrame, false, false); + followingFrame.release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .then(payload::release) + .then(() -> rsf.handlePayload(payload2)) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .assertNext( + p -> { + Assertions.assertThat(p.data()).isEqualTo(Unpooled.wrappedBuffer(data)); + + Assertions.assertThat(p.metadata()).isEqualTo(Unpooled.wrappedBuffer(metadata)); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasRequestN(Integer.MAX_VALUE) + .hasSubscribedFlag() + .hasFirstFrameSentFlag()) + .then(rsf::handleComplete) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf).isTerminated()) + .expectComplete(); + }, + (rsf, sv) -> { + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload0 = ByteBufPayload.create(data, metadata); + final Payload payload = ByteBufPayload.create(data, metadata); + + ByteBuf[] fragments = + new ByteBuf[] { + FragmentationUtils.encodeFirstFragment( + rsf.allocator, + 64, + FrameType.NEXT, + 1, + payload.hasMetadata(), + payload.metadata(), + payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()), + FragmentationUtils.encodeFollowsFragment( + rsf.allocator, 64, 1, false, payload.metadata(), payload.data()) + }; + + final StepVerifier stepVerifier = + sv.then(() -> rsf.handlePayload(payload0)) + .assertNext(p -> PayloadAssert.assertThat(p).isEqualTo(payload0).hasNoLeaks()) + .thenRequest(Long.MAX_VALUE) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasNoReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[0], true, false); + fragments[0].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[1], true, false); + fragments[1].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then( + () -> { + rsf.handleNext(fragments[2], true, false); + fragments[2].release(); + }) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(rsf) + .hasSubscribedFlag() + .hasRequestN(Integer.MAX_VALUE) + .hasFirstFrameSentFlag() + .hasReassemblingFlag()) + .then(payload::release) + .thenCancel() + .verifyLater(); + + stepVerifier.verify(); + // state machine check + StateAssert.assertThat(rsf).isTerminated(); + + Assertions.assertThat(fragments).allMatch(bb -> bb.refCnt() == 0); + + return stepVerifier; + }); + } + + /** + * State Machine check with fragmentation of the first payload. Ensure migration from + * + *
+   * UNSUBSCRIBED -> SUBSCRIBED
+   * SUBSCRIBED -> REQUESTED(1) -> REQUESTED(0)
+   * 
+ * + * And then for the following cases: + * + *
+   * [0]: REQUESTED(0) -> REQUESTED(MAX) (with onNext and few extra request(1) which should not
+   * affect state anyhow and should not sent any extra frames)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [1]: REQUESTED(0) -> REQUESTED(MAX) (with onComplete rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [2]: REQUESTED(0) -> REQUESTED(MAX) (with onError rightaway)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [3]: REQUESTED(0) -> REASSEMBLY
+   *      REASSEMBLY -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> TERMINATED
+   *
+   * [4]: REQUESTED(0) -> REQUESTED(MAX)
+   *      REQUESTED(MAX) -> REASSEMBLY && REQUESTED(MAX)
+   *      REASSEMBLY && REQUESTED(MAX) -> TERMINATED (because of cancel() invocation)
+   * 
+ */ + @ParameterizedTest + @MethodSource("frameShouldBeSentOnFirstRequestResponses") + public void frameFragmentsShouldBeSentOnFirstRequest( + BiFunction, StepVerifier> + transformer) { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + transformer + .apply( + requestStreamRequesterFlux, + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then(() -> Assertions.assertThat(payload.refCnt()).isOne()) + .then(() -> activeStreams.assertNoActiveStreams()) + .thenRequest(1) + .then(() -> Assertions.assertThat(payload.refCnt()).isZero()) + .then(() -> activeStreams.assertHasStream(1, requestStreamRequesterFlux))) + .verify(); + + // should not add anything to map + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(payload.refCnt()).isZero(); + + final ByteBuf frameFragment1 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment1) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N) + // InitialRequestN size + .hasMetadata(Arrays.copyOf(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N)) + .hasData(Unpooled.EMPTY_BUFFER) + .hasFragmentsFollow() + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment2 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment2) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET_WITH_METADATA) + .hasMetadata( + Arrays.copyOfRange(metadata, 64 - FRAME_OFFSET_WITH_METADATA_AND_INITIAL_REQUEST_N, 65)) + .hasData(Arrays.copyOf(data, 35)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment3 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment3) + .isNotNull() + .hasPayloadSize(64 - FRAME_OFFSET) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 35, 35 + 55)) + .hasFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf frameFragment4 = sender.awaitFrame(); + FrameAssert.assertThat(frameFragment4) + .isNotNull() + .hasPayloadSize(39) + .hasNoMetadata() + .hasData(Arrays.copyOfRange(data, 90, 129)) + .hasNoFragmentsFollow() + .typeOf(FrameType.NEXT) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf requestNFrame = sender.awaitFrame(); + FrameAssert.assertThat(requestNFrame) + .isNotNull() + .typeOf(FrameType.REQUEST_N) + .hasRequestN(Integer.MAX_VALUE) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (!sender.isEmpty()) { + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + } + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Case which ensures that if Payload has incorrect refCnt, the flux ends up with an appropriate + * error + */ + @ParameterizedTest + @MethodSource("shouldErrorOnIncorrectRefCntInGivenPayloadSource") + public void shouldErrorOnIncorrectRefCntInGivenPayload( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = ByteBufPayload.create(""); + payload.release(); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorOnIncorrectRefCntInGivenPayloadSource() { + return Stream.of( + (s) -> + StepVerifier.create(s) + .expectSubscription() + .expectError(IllegalReferenceCountException.class) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .isInstanceOf(IllegalReferenceCountException.class)); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will exponse + * the error immediatelly and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhase() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final Payload payload = ByteBufPayload.create(""); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if Payload is release right after the subscription, the first request will expose + * the error immediately and no frame will be sent to the remote party + */ + @Test + public void shouldErrorOnIncorrectRefCntInGivenPayloadLatePhaseWithFragmentation() { + final int mtu = 64; + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(mtu); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[65]; + final byte[] data = new byte[129]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + StepVerifier.create(requestStreamRequesterFlux, 0) + .expectSubscription() + .then( + () -> + // state machine check + stateAssert.hasSubscribedFlagOnly()) + .then(payload::release) + .thenRequest(1) + .then( + () -> + // state machine check + stateAssert.isTerminated()) + .expectError(IllegalReferenceCountException.class) + .verify(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + /** + * Ensures that if the given payload is exits 16mb size with disabled fragmentation, than the + * appropriate validation happens and a corresponding error will be propagagted to the subscriber + */ + @ParameterizedTest + @MethodSource("shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource") + public void shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabled( + Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + + final byte[] metadata = new byte[FRAME_LENGTH_MASK]; + final byte[] data = new byte[FRAME_LENGTH_MASK]; + ThreadLocalRandom.current().nextBytes(metadata); + ThreadLocalRandom.current().nextBytes(data); + + final Payload payload = ByteBufPayload.create(data, metadata); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(sender.isEmpty()).isTrue(); + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + static Stream> + shouldErrorIfFragmentExitsAllowanceIfFragmentationDisabledSource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage( + String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage(String.format(INVALID_PAYLOAD_ERROR_MESSAGE, FRAME_LENGTH_MASK)) + .isInstanceOf(IllegalArgumentException.class)); + } + + /** + * Ensures that the interactions check and respect rsocket availability (such as leasing) and + * propagate an error to the final subscriber. No frame should be sent. Check should happens + * exactly on the first request. + */ + @ParameterizedTest + @MethodSource("shouldErrorIfNoAvailabilitySource") + public void shouldErrorIfNoAvailability(Consumer monoConsumer) { + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(new RuntimeException("test")); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + monoConsumer.accept(requestStreamRequesterFlux); + + Assertions.assertThat(payload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + allocator.assertHasNoLeaks(); + } + + static Stream> shouldErrorIfNoAvailabilitySource() { + return Stream.of( + (s) -> + StepVerifier.create(s, 0) + .expectSubscription() + .then( + () -> + // state machine check + StateAssert.assertThat(s).hasSubscribedFlagOnly()) + .thenRequest(1) + .then( + () -> + // state machine check + StateAssert.assertThat(s).isTerminated()) + .consumeErrorWith( + t -> + Assertions.assertThat(t) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)) + .verify(), + requestStreamRequesterFlux -> + Assertions.assertThatThrownBy(requestStreamRequesterFlux::blockLast) + .hasMessage("test") + .isInstanceOf(RuntimeException.class)); + } + + @Test + public void failOnOverflow() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final TestDuplexConnection sender = activeStreams.getDuplexConnection(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + final StateAssert stateAssert = + StateAssert.assertThat(requestStreamRequesterFlux); + + // state machine check + + stateAssert.isUnsubscribed(); + activeStreams.assertNoActiveStreams(); + + final AssertSubscriber assertSubscriber = + requestStreamRequesterFlux.subscribeWith(AssertSubscriber.create(0)); + Assertions.assertThat(payload.refCnt()).isOne(); + activeStreams.assertNoActiveStreams(); + // state machine check + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertHasStream(1, requestStreamRequesterFlux); + + // state machine check + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf frame = sender.awaitFrame(); + FrameAssert.assertThat(frame) + .isNotNull() + .hasPayloadSize( + "testData".getBytes(CharsetUtil.UTF_8).length + + "testMetadata".getBytes(CharsetUtil.UTF_8).length) + .hasMetadata("testMetadata") + .hasData("testData") + .hasNoFragmentsFollow() + .hasRequestN(1) + .typeOf(FrameType.REQUEST_STREAM) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + Payload requestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(requestedPayload); + + Payload unrequestedPayload = TestRequesterResponderSupport.randomPayload(allocator); + requestStreamRequesterFlux.handlePayload(unrequestedPayload); + + final ByteBuf cancelFrame = sender.awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + assertSubscriber + .assertValuesWith(p -> PayloadAssert.assertThat(p).isEqualTo(requestedPayload).hasNoLeaks()) + .assertError() + .assertErrorMessage("The number of messages received exceeds the number requested"); + + PayloadAssert.assertThat(requestedPayload).isReleased(); + PayloadAssert.assertThat(unrequestedPayload).isReleased(); + + Assertions.assertThat(payload.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + + Assertions.assertThat(sender.isEmpty()).isTrue(); + + // state machine check + stateAssert.isTerminated(); + allocator.assertHasNoLeaks(); + } + + @Test + public void checkName() { + final TestRequesterResponderSupport activeStreams = TestRequesterResponderSupport.client(); + final LeaksTrackingByteBufAllocator allocator = activeStreams.getAllocator(); + final Payload payload = TestRequesterResponderSupport.genericPayload(allocator); + + final RequestStreamRequesterFlux requestStreamRequesterFlux = + new RequestStreamRequesterFlux(payload, activeStreams); + + Assertions.assertThat(Scannable.from(requestStreamRequesterFlux).name()) + .isEqualTo("source(RequestStreamFlux)"); + requestStreamRequesterFlux.cancel(); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java new file mode 100644 index 000000000..06d050f6f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/RequesterOperatorsRacingTest.java @@ -0,0 +1,790 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.COMPLETE; +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_N; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.CharsetUtil; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.RequestStreamFrameCodec; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Supplier; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +@SuppressWarnings("ALL") +public class RequesterOperatorsRacingTest { + + interface Scenario { + FrameType requestType(); + + Publisher requestOperator( + Supplier payloadsSupplier, RequesterResponderSupport requesterResponderSupport); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return METADATA_PUSH; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new MetadataPushRequesterMono(payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return MetadataPushRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_FNF; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new FireAndForgetRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return FireAndForgetRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestResponseRequesterMono( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestStreamRequesterFlux( + payloadsSupplier.get(), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestStreamRequesterFlux.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public Publisher requestOperator( + Supplier payloadsSupplier, + RequesterResponderSupport requesterResponderSupport) { + return new RequestChannelRequesterFlux( + Flux.generate(s -> s.next(payloadsSupplier.get())), requesterResponderSupport); + } + + @Override + public String toString() { + return RequestChannelRequesterFlux.class.getSimpleName(); + } + }); + } + + /* + * +--------------------------------+ + * | Racing Test Cases | + * +--------------------------------+ + */ + + /** Ensures single subscription happens in case of racing */ + @ParameterizedTest(name = "Should subscribe exactly once to {0}") + @MethodSource("scenarios") + public void shouldSubscribeExactlyOnce(Scenario scenario) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport requesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> + TestRequesterResponderSupport.genericPayload( + requesterResponderSupport.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, requesterResponderSupport); + + StepVerifier stepVerifier = + StepVerifier.create(requesterResponderSupport.getDuplexConnection().getSentAsPublisher()) + .assertNext( + frame -> { + FrameAssert frameAssert = + FrameAssert.assertThat(frame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()); + if (scenario.requestType() == METADATA_PUSH) { + frameAssert + .hasStreamIdZero() + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT); + } else { + frameAssert + .hasClientSideStreamId() + .hasStreamId(1) + .hasPayloadSize( + TestRequesterResponderSupport.METADATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length + + TestRequesterResponderSupport.DATA_CONTENT.getBytes( + CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT); + } + frameAssert.hasNoLeaks(); + + if (requestOperator instanceof FrameHandler) { + ((FrameHandler) requestOperator).handleComplete(); + if (scenario.requestType() == REQUEST_CHANNEL) { + ((FrameHandler) requestOperator).handleCancel(); + } + } + }) + .thenCancel() + .verifyLater(); + + Assertions.assertThatThrownBy( + () -> + RaceTestUtils.race( + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + }, + () -> { + AssertSubscriber subscriber = new AssertSubscriber<>(); + requestOperator.subscribe(subscriber); + subscriber.await().assertTerminated().assertNoError(); + })) + .matches( + t -> { + Assertions.assertThat(t).hasMessageContaining("allows only a single Subscriber"); + return true; + }); + + stepVerifier.verify(Duration.ofSeconds(1)); + requesterResponderSupport.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_START, + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .assertNext( + event -> + Assertions.assertThat(event.eventType) + .isIn( + TestRequestInterceptor.EventType.ON_COMPLETE, + TestRequestInterceptor.EventType.ON_REJECT)) + .expectNothing(); + } + } + } + + /** Ensures single frame is sent only once racing between requests */ + @ParameterizedTest(name = "{0} should sent requestFrame exactly once if request(n) is racing") + @MethodSource("scenarios") + public void shouldSentRequestFrameOnceInCaseOfRequestRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + requestOperator.subscribe(assertSubscriber); + + RaceTestUtils.race(() -> assertSubscriber.request(1), () -> assertSubscriber.request(1)); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + + if (scenario.requestType().hasInitialRequestN()) { + if (RequestStreamFrameCodec.initialRequestN(sentFrame) == 1) { + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .isNotNull() + .hasStreamId(1) + .hasRequestN(1) + .typeOf(REQUEST_N) + .hasNoLeaks(); + } else { + Assertions.assertThat(RequestStreamFrameCodec.initialRequestN(sentFrame)).isEqualTo(2); + } + } + + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + ((RequesterFrameHandler) requestOperator).handlePayload(response); + ((RequesterFrameHandler) requestOperator).handleComplete(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + ((CoreSubscriber) requestOperator).onComplete(); + FrameAssert.assertThat(activeStreams.getDuplexConnection().awaitFrame()) + .typeOf(COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + } + + assertSubscriber + .assertTerminated() + .assertValuesWith( + p -> { + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + }); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + if (scenario.requestType() != METADATA_PUSH) { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + } + } + + /** + * Ensures that no ByteBuf is leaked if reassembly is starting and cancel is happening at the same + * time + */ + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnReassemblyAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + (Publisher) scenario.requestOperator(payloadSupplier, activeStreams); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(1); + + requestOperator.subscribe(assertSubscriber); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload responsePayload = + TestRequesterResponderSupport.randomPayload(activeStreams.getAllocator()); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments( + activeStreams.getAllocator(), mtu, responsePayload); + RaceTestUtils.race( + assertSubscriber::cancel, + () -> { + FrameHandler frameHandler = (FrameHandler) requestOperator; + int lastFragmentId = fragments.size() - 1; + for (int j = 0; j < fragments.size(); j++) { + ByteBuf frame = fragments.get(j); + frameHandler.handleNext(frame, lastFragmentId != j, lastFragmentId == j); + frame.release(); + } + }); + + List values = assertSubscriber.values(); + if (!values.isEmpty()) { + Assertions.assertThat(values) + .hasSize(1) + .first() + .matches( + p -> { + Assertions.assertThat(p.sliceData()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceData())); + Assertions.assertThat(p.hasMetadata()).isEqualTo(responsePayload.hasMetadata()); + Assertions.assertThat(p.sliceMetadata()) + .matches(bb -> ByteBufUtil.equals(bb, responsePayload.sliceMetadata())); + Assertions.assertThat(p.release()).isTrue(); + Assertions.assertThat(p.refCnt()).isZero(); + return true; + }); + } + + if (!activeStreams.getDuplexConnection().isEmpty()) { + if (scenario.requestType() != REQUEST_CHANNEL) { + assertSubscriber.assertNotTerminated(); + } + + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + + Assertions.assertThat(responsePayload.release()).isTrue(); + Assertions.assertThat(responsePayload.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case of racing between next element and cancel we will not have any memory + * leaks + */ + @ParameterizedTest(name = "Should have no leaks when {0} is canceled during reassembly") + @MethodSource("scenarios") + public void shouldHaveNoLeaksOnNextAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handlePayload(response)); + + assertSubscriber.values().forEach(Payload::release); + Assertions.assertThat(response.refCnt()).isZero(); + + activeStreams.assertNoActiveStreams(); + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + assertSubscriber.assertTerminated(); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** + * Ensures that in case we have element reassembling and then it happens the remote sends + * (errorFrame) and downstream subscriber sends cancel() and we have racing between onError and + * cancel we will not have any memory leaks + */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldHaveNoUnexpectedErrorDuringOnErrorAndCancelRacing(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + boolean[] withReassemblyOptions = new boolean[] {true, false}; + final ArrayList droppedErrors = new ArrayList<>(); + Hooks.onErrorDropped(droppedErrors::add); + + try { + for (boolean withReassembly : withReassemblyOptions) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + final StateAssert stateAssert; + if (requestOperator instanceof RequestResponseRequesterMono) { + stateAssert = StateAssert.assertThat((RequestResponseRequesterMono) requestOperator); + } else if (requestOperator instanceof RequestStreamRequesterFlux) { + stateAssert = StateAssert.assertThat((RequestStreamRequesterFlux) requestOperator); + } else { + stateAssert = StateAssert.assertThat((RequestChannelRequesterFlux) requestOperator); + } + + stateAssert.isUnsubscribed(); + final AssertSubscriber assertSubscriber = AssertSubscriber.create(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + stateAssert.hasSubscribedFlagOnly(); + + assertSubscriber.request(1); + + stateAssert.hasSubscribedFlag().hasRequestN(1).hasFirstFrameSentFlag(); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasPayloadSize( + TestRequesterResponderSupport.DATA_CONTENT.getBytes(CharsetUtil.UTF_8).length + + TestRequesterResponderSupport.METADATA_CONTENT.getBytes(CharsetUtil.UTF_8) + .length) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + if (withReassembly) { + final ByteBuf fragmentBuf = + activeStreams.getAllocator().buffer().writeBytes(new byte[] {1, 2, 3}); + ((RequesterFrameHandler) requestOperator).handleNext(fragmentBuf, true, false); + // mimic frameHandler behaviour + fragmentBuf.release(); + } + + final RuntimeException testException = new RuntimeException("test"); + RaceTestUtils.race( + ((Subscription) requestOperator)::cancel, + () -> ((RequesterFrameHandler) requestOperator).handleError(testException)); + + activeStreams.assertNoActiveStreams(); + stateAssert.isTerminated(); + + final boolean isEmpty = activeStreams.getDuplexConnection().isEmpty(); + if (!isEmpty) { + final ByteBuf cancellationFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancellationFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + Assertions.assertThat(droppedErrors).containsExactly(testException); + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } else { + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnError(1) + .expectNothing(); + + assertSubscriber.assertTerminated().assertErrorMessage("test"); + } + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + + stateAssert.isTerminated(); + droppedErrors.clear(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + /** + * Ensures that in case of racing between first request and cancel does not going to introduce + * leaks.
+ *
+ * + *

Please note, first request may or may not happen so in case it happened before cancellation + * signal we have to observe + * + *

    + *
  • RequestResponseFrame + *
  • CancellationFrame + *
+ * + *

exactly in that order + * + *

Ensures full serialization of outgoing signal (frames) + */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldBeConsistentInCaseOfRacingOfCancellationAndRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requestOperator = scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requestOperator.subscribe((AssertSubscriber) assertSubscriber); + + RaceTestUtils.race(() -> assertSubscriber.cancel(), () -> assertSubscriber.request(1)); + + if (!activeStreams.getDuplexConnection().isEmpty()) { + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .typeOf(scenario.requestType()) + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasNoFragmentsFollow() + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + } + + ((RequesterFrameHandler) requestOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); + + Assertions.assertThat(response.refCnt()).isZero(); + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } + + /** Ensures that CancelFrame is sent exactly once in case of racing between cancel() methods */ + @ParameterizedTest + @MethodSource("scenarios") + public void shouldSentCancelFrameExactlyOnce(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()) + .isIn(REQUEST_RESPONSE, REQUEST_STREAM, REQUEST_CHANNEL); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequesterResponderSupport activeStreams = + TestRequesterResponderSupport.client(testRequestInterceptor); + final Supplier payloadSupplier = + () -> TestRequesterResponderSupport.genericPayload(activeStreams.getAllocator()); + + final Publisher requesterOperator = + scenario.requestOperator(payloadSupplier, activeStreams); + + Payload response = ByteBufPayload.create("test", "test"); + + final AssertSubscriber assertSubscriber = new AssertSubscriber<>(0); + + requesterOperator.subscribe((AssertSubscriber) assertSubscriber); + + assertSubscriber.request(1); + + final ByteBuf sentFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(sentFrame) + .isNotNull() + .hasNoFragmentsFollow() + .typeOf(scenario.requestType()) + .hasClientSideStreamId() + .hasMetadata(TestRequesterResponderSupport.METADATA_CONTENT) + .hasData(TestRequesterResponderSupport.DATA_CONTENT) + .hasStreamId(1) + .hasNoLeaks(); + + RaceTestUtils.race( + ((Subscription) requesterOperator)::cancel, ((Subscription) requesterOperator)::cancel); + + final ByteBuf cancelFrame = activeStreams.getDuplexConnection().awaitFrame(); + FrameAssert.assertThat(cancelFrame) + .isNotNull() + .typeOf(FrameType.CANCEL) + .hasClientSideStreamId() + .hasStreamId(1) + .hasNoLeaks(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + + activeStreams.assertNoActiveStreams(); + + ((RequesterFrameHandler) requesterOperator).handlePayload(response); + assertSubscriber.values().forEach(Payload::release); + Assertions.assertThat(response.refCnt()).isZero(); + + ((RequesterFrameHandler) requesterOperator).handleComplete(); + assertSubscriber.assertNotTerminated(); + + activeStreams.assertNoActiveStreams(); + Assertions.assertThat(activeStreams.getDuplexConnection().isEmpty()).isTrue(); + activeStreams.getAllocator().assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java new file mode 100644 index 000000000..382240c4a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ResolvingOperatorTests.java @@ -0,0 +1,1030 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Queue; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Predicate; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Condition; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class ResolvingOperatorTests { + + @Test + public void shouldExpireValueOnRacingDisposeAndComplete() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final int index = i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .thenAddObserver(consumer) + .assertPendingSubscribers(1) + .assertPendingResolution() + .then(self -> RaceTestUtils.race(() -> self.complete("value" + index), self::dispose)) + .assertDisposeCalled(1) + .assertExpiredExactly("value" + index) + .ifResolvedAssertEqual("value" + index) + .assertIsDisposed(); + + subscriber.assertTerminated(); + + if (!subscriber.errors().isEmpty()) { + Assertions.assertThat(subscriber.errors().get(0)) + .isInstanceOf(CancellationException.class) + .hasMessage("Disposed"); + + } else { + Assertions.assertThat(subscriber.values()).containsExactly("value" + i); + } + } + } + + @Test + public void shouldNotifyAllTheSubscribersUnderRacingBetweenSubscribeAndComplete() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .then( + self -> { + RaceTestUtils.race(() -> self.complete(valueToSend), () -> self.observe(consumer)); + + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + }) + .assertDisposeCalled(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .thenAddObserver(consumer2) + .assertPendingSubscribers(0); + + subscriber2.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidate() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + final String valueToSend2 = "value2" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .then( + self -> { + self.complete(valueToSend); + + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + }) + .assertReceivedExactly(valueToSend) + .then( + self -> + RaceTestUtils.race( + self::invalidate, + () -> { + self.observe(consumer2); + if (!subscriber2.isTerminated()) { + self.complete(valueToSend2); + } + })) + .then( + self -> { + if (self.isPending()) { + self.assertReceivedExactly(valueToSend); + } else { + self.assertReceivedExactly(valueToSend, valueToSend2); + } + }) + .assertExpiredExactly(valueToSend) + .assertPendingSubscribers(0) + .assertDisposeCalled(0) + .then( + self -> + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { + if (self.subscribers == ResolvingOperator.READY) { + Assertions.assertThat(v).isEqualTo(valueToSend2); + } else { + Assertions.assertThat(v).isEqualTo(valueToSend); + } + }) + .assertComplete()); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfSubscribeIsRacingWithInvalidates() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + final String valueToSend2 = "value_to_possibly_expire" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .then( + self -> { + self.complete(valueToSend); + + subscriber.await(Duration.ofMillis(100)).assertValues(valueToSend).assertComplete(); + }) + .assertReceivedExactly(valueToSend) + .then( + self -> + RaceTestUtils.race( + self::invalidate, + self::invalidate, + () -> { + self.observe(consumer2); + if (!subscriber2.isTerminated()) { + self.complete(valueToSend2); + } + })) + .then( + self -> { + if (!self.isPending()) { + self.assertReceivedExactly(valueToSend, valueToSend2); + } else { + if (self.received.size() > 1) { + self.assertReceivedExactly(valueToSend, valueToSend2); + } else { + self.assertReceivedExactly(valueToSend); + } + } + + Assertions.assertThat(self.expired) + .haveAtMost( + 2, + new Condition<>( + new Predicate() { + int time = 0; + + @Override + public boolean test(Object s) { + if (time++ == 0) { + return valueToSend.equals(s); + } else { + return valueToSend2.equals(s); + } + } + }, + "should matches one of the given values")); + }) + .assertPendingSubscribers(0) + .assertDisposeCalled(0) + .then( + self -> + subscriber2 + .await(Duration.ofMillis(100)) + .assertValueCount(1) + .assertValuesWith( + v -> { + if (self.subscribers == ResolvingOperator.READY) { + Assertions.assertThat(v).isEqualTo(valueToSend2); + } else { + Assertions.assertThat(v).isIn(valueToSend, valueToSend2); + } + }) + .assertComplete()); + } + } + + @Test + public void shouldNotExpireNewlyResolvedValueIfBlockIsRacingWithInvalidate() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + final String valueToSend2 = "value2" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .then( + self -> { + self.complete(valueToSend); + + subscriber.await(Duration.ofMillis(10)).assertValues(valueToSend).assertComplete(); + }) + .assertReceivedExactly(valueToSend) + .then( + self -> + RaceTestUtils.race( + () -> + Assertions.assertThat(self.block(null)) + .matches((v) -> v.equals(valueToSend) || v.equals(valueToSend2)), + self::invalidate, + () -> { + for (; ; ) { + if (self.subscribers != ResolvingOperator.READY) { + self.complete(valueToSend2); + break; + } + } + })) + .then( + self -> { + if (self.isPending()) { + self.assertReceivedExactly(valueToSend); + } else { + self.assertReceivedExactly(valueToSend, valueToSend2); + } + }) + .assertExpiredExactly(valueToSend) + .assertPendingSubscribers(0) + .assertDisposeCalled(0); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribers() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .then( + self -> + RaceTestUtils.race(() -> self.observe(consumer), () -> self.observe(consumer2))) + .assertSubscribeCalled(1) + .assertPendingSubscribers(2) + .then(self -> self.complete(valueToSend)) + .assertPendingSubscribers(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .assertDisposeCalled(0) + .then( + self -> { + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); + + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); + + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); + + Assertions.assertThat(self.add(consumer)).isEqualTo(ResolvingOperator.READY_STATE); + }); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenSubscribeAndBlock() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .whenSubscribe(self -> self.complete(valueToSend)) + .then( + self -> + RaceTestUtils.race( + () -> { + subscriber.onNext(self.block(null)); + subscriber.onComplete(); + }, + () -> self.observe(consumer2))) + .assertSubscribeCalled(1) + .assertPendingSubscribers(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .assertDisposeCalled(0) + .then( + self -> { + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); + + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); + + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); + + Assertions.assertThat(self.add(consumer2)).isEqualTo(ResolvingOperator.READY_STATE); + }); + } + } + + @Test + public void shouldEstablishValueOnceInCaseOfRacingBetweenBlocks() { + Duration timeout = Duration.ofMillis(100); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final String valueToSend = "value" + i; + + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .whenSubscribe(self -> self.complete(valueToSend)) + .then( + self -> + RaceTestUtils.race( + () -> { + subscriber.onNext(self.block(timeout)); + subscriber.onComplete(); + }, + () -> { + subscriber2.onNext(self.block(timeout)); + subscriber2.onComplete(); + })) + .assertSubscribeCalled(1) + .assertPendingSubscribers(0) + .assertReceivedExactly(valueToSend) + .assertNothingExpired() + .assertDisposeCalled(0) + .then( + self -> { + Assertions.assertThat(subscriber.isTerminated()).isTrue(); + Assertions.assertThat(subscriber2.isTerminated()).isTrue(); + + Assertions.assertThat(subscriber.values()).containsExactly(valueToSend); + Assertions.assertThat(subscriber2.values()).containsExactly(valueToSend); + + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.READY); + + Assertions.assertThat(self.add((v, t) -> {})) + .isEqualTo(ResolvingOperator.READY_STATE); + }); + } + } + + @Test + public void shouldExpireValueOnRacingDisposeAndError() { + Hooks.onErrorDropped(t -> {}); + RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subscriber = AssertSubscriber.create(); + subscriber.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer = + (v, t) -> { + if (t != null) { + subscriber.onError(t); + return; + } + + subscriber.onNext(v); + subscriber.onComplete(); + }; + + AssertSubscriber subscriber2 = AssertSubscriber.create(); + subscriber2.onSubscribe(Operators.emptySubscription()); + BiConsumer consumer2 = + (v, t) -> { + if (t != null) { + subscriber2.onError(t); + return; + } + + subscriber2.onNext(v); + subscriber2.onComplete(); + }; + + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .thenAddObserver(consumer) + .assertSubscribeCalled(1) + .assertPendingSubscribers(1) + .then(self -> RaceTestUtils.race(() -> self.terminate(runtimeException), self::dispose)) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertDisposeCalled(1) + .then( + self -> { + Assertions.assertThat(self.subscribers).isEqualTo(ResolvingOperator.TERMINATED); + + Assertions.assertThat(self.add((v, t) -> {})) + .isEqualTo(ResolvingOperator.TERMINATED_STATE); + }) + .thenAddObserver(consumer2); + + subscriber + .await(Duration.ofMillis(10)) + .assertErrorWith( + t -> { + if (t instanceof CancellationException) { + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Disposed"); + } else { + Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); + } + }); + + subscriber2 + .await(Duration.ofMillis(10)) + .assertErrorWith( + t -> { + if (t instanceof CancellationException) { + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Disposed"); + } else { + Assertions.assertThat(t).isInstanceOf(RuntimeException.class).hasMessage("test"); + } + }); + + // no way to guarantee equality because of racing + // Assertions.assertThat(processor.getError()) + // .isEqualTo(processor2.getError()); + } + } + + @Test + public void shouldThrowOnBlocking() { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .then( + self -> + Assertions.assertThatThrownBy(() -> self.block(Duration.ofMillis(100))) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on Mono blocking read")) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertNothingReceived() + .assertDisposeCalled(0); + } + + @Test + public void shouldThrowOnBlockingIfHasAlreadyTerminated() { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingSubscribers(0) + .assertPendingResolution() + .whenSubscribe(self -> self.terminate(new RuntimeException("test"))) + .then( + self -> + Assertions.assertThatThrownBy(() -> self.block(Duration.ofMillis(100))) + .isInstanceOf(RuntimeException.class) + .hasMessage("test") + .hasSuppressedException(new Exception("Terminated with an error"))) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertNothingReceived() + .assertDisposeCalled(1); + } + + static Stream, Publisher>> innerCases() { + return Stream.of( + (self) -> { + final Sinks.One processor = Sinks.unsafe().one(); + final ResolvingOperator.DeferredResolution operator = + new ResolvingOperator.DeferredResolution( + self, new SinkOneSubscriber(processor)) { + @Override + public void accept(String v, Throwable t) { + if (t != null) { + onError(t); + return; + } + + onNext(v); + } + }; + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); + }, + (self) -> { + final Sinks.One processor = Sinks.unsafe().one(); + final SinkOneSubscriber subscriber = new SinkOneSubscriber(processor); + final ResolvingOperator.MonoDeferredResolutionOperator operator = + new ResolvingOperator.MonoDeferredResolutionOperator<>(self, subscriber); + subscriber.onSubscribe(operator); + return processor + .asMono() + .doOnSubscribe(s -> self.observe(operator)) + .doOnCancel(operator::cancel); + }); + } + + @ParameterizedTest + @MethodSource("innerCases") + public void shouldBePossibleToRemoveThemSelvesFromTheList_CancellationTest( + Function, Publisher> caseProducer) { + ResolvingTest.create() + .then( + self -> { + Publisher resolvingInner = caseProducer.apply(self); + StepVerifier.create(resolvingInner) + .expectSubscription() + .then(() -> self.assertSubscribeCalled(1).assertPendingSubscribers(1)) + .thenCancel() + .verify(Duration.ofMillis(100)); + }) + .assertPendingSubscribers(0) + .assertNothingExpired() + .then(self -> self.complete("test")) + .assertReceivedExactly("test"); + } + + @ParameterizedTest + @MethodSource("innerCases") + public void shouldExpireValueOnDispose( + Function, Publisher> caseProducer) { + ResolvingTest.create() + .then( + self -> { + Publisher resolvingInner = caseProducer.apply(self); + + StepVerifier.create(resolvingInner) + .expectSubscription() + .then(() -> self.complete("test")) + .expectNext("test") + .expectComplete() + .verify(Duration.ofMillis(100)); + }) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertReceivedExactly("test") + .then(ResolvingOperator::dispose) + .assertExpiredExactly("test") + .assertDisposeCalled(1); + } + + @ParameterizedTest + @MethodSource("innerCases") + public void shouldNotifyAllTheSubscribers( + Function, Publisher> caseProducer) { + + AssertSubscriber sub1 = AssertSubscriber.create(); + AssertSubscriber sub2 = AssertSubscriber.create(); + AssertSubscriber sub3 = AssertSubscriber.create(); + AssertSubscriber sub4 = AssertSubscriber.create(); + + final ArrayList> processors = + new ArrayList<>(RaceTestConstants.REPEATS * 2); + + ResolvingTest.create() + .assertDisposeCalled(0) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .then( + self -> { + caseProducer.apply(self).subscribe(sub1); + caseProducer.apply(self).subscribe(sub2); + caseProducer.apply(self).subscribe(sub3); + caseProducer.apply(self).subscribe(sub4); + }) + .assertSubscribeCalled(1) + .assertPendingSubscribers(4) + .then( + self -> { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + AssertSubscriber subA = AssertSubscriber.create(); + AssertSubscriber subB = AssertSubscriber.create(); + processors.add(subA); + processors.add(subB); + RaceTestUtils.race( + () -> caseProducer.apply(self).subscribe(subA), + () -> caseProducer.apply(self).subscribe(subB)); + } + }) + .assertSubscribeCalled(1) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 4) + .then(self -> sub1.cancel()) + .assertPendingSubscribers(RaceTestConstants.REPEATS * 2 + 3) + .then( + self -> { + String valueToSend = "value"; + self.complete(valueToSend); + + Assertions.assertThat(sub1.isTerminated()).isFalse(); + Assertions.assertThat(sub2.values()).containsExactly(valueToSend); + Assertions.assertThat(sub3.values()).containsExactly(valueToSend); + Assertions.assertThat(sub4.values()).containsExactly(valueToSend); + + for (AssertSubscriber sub : processors) { + Assertions.assertThat(sub.values()).containsExactly(valueToSend); + Assertions.assertThat(sub.isTerminated()).isTrue(); + } + }) + .assertPendingSubscribers(0) + .assertNothingExpired() + .assertReceivedExactly("value"); + } + + @Test + public void shouldBeSerialIfRacyMonoInner() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + long[] requested = new long[] {0}; + Subscription mockSubscription = Mockito.mock(Subscription.class); + Mockito.doAnswer( + a -> { + long argument = a.getArgument(0); + return requested[0] += argument; + }) + .when(mockSubscription) + .request(Mockito.anyLong()); + ResolvingOperator.DeferredResolution resolution = + new ResolvingOperator.DeferredResolution( + ResolvingTest.create(), AssertSubscriber.create(0)) { + + @Override + public void accept(Object o, Object o2) {} + }; + + resolution.request(5); + + RaceTestUtils.race( + () -> resolution.onSubscribe(mockSubscription), + () -> { + resolution.request(10); + resolution.request(10); + resolution.request(10); + }); + + resolution.request(15); + + Assertions.assertThat(requested[0]).isEqualTo(50L); + } + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidates() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .then(self -> self.complete("test")) + .assertReceivedExactly("test") + .then(self -> RaceTestUtils.race(self::invalidate, self::invalidate)) + .assertExpiredExactly("test"); + } + } + + @Test + public void shouldExpireValueExactlyOnceOnRacingBetweenInvalidateAndDispose() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + ResolvingTest.create() + .assertNothingExpired() + .assertNothingReceived() + .assertPendingResolution() + .then(self -> self.complete("test")) + .assertReceivedExactly("test") + .then(self -> RaceTestUtils.race(self::invalidate, self::dispose)) + .assertExpiredExactly("test"); + } + } + + static class ResolvingTest extends ResolvingOperator { + + final AtomicInteger subscribeCalls = new AtomicInteger(); + final AtomicInteger onDisposeCalls = new AtomicInteger(); + + final Queue received = new ConcurrentLinkedQueue<>(); + final Queue expired = new ConcurrentLinkedQueue<>(); + + Consumer> whenSubscribeConsumer = (self) -> {}; + + static ResolvingTest create() { + return new ResolvingTest<>(); + } + + public ResolvingTest assertPendingSubscribers(int cnt) { + Assertions.assertThat(this.subscribers.length).isEqualTo(cnt); + + return this; + } + + public ResolvingTest whenSubscribe(Consumer> consumer) { + this.whenSubscribeConsumer = consumer; + return this; + } + + public ResolvingTest then(Consumer> consumer) { + consumer.accept(this); + + return this; + } + + public ResolvingTest thenAddObserver(BiConsumer consumer) { + this.observe(consumer); + return this; + } + + public ResolvingTest assertPendingResolution() { + Assertions.assertThat(this.isPending()).isTrue(); + + return this; + } + + public ResolvingTest assertIsDisposed() { + Assertions.assertThat(this.isDisposed()).isTrue(); + + return this; + } + + public ResolvingTest assertSubscribeCalled(int times) { + Assertions.assertThat(subscribeCalls).hasValue(times); + + return this; + } + + public ResolvingTest assertDisposeCalled(int times) { + Assertions.assertThat(onDisposeCalls).hasValue(times); + return this; + } + + public ResolvingTest assertNothingExpired() { + return assertExpiredExactly(); + } + + public ResolvingTest assertExpiredExactly(T... values) { + Assertions.assertThat(expired).hasSize(values.length).containsExactly(values); + + return this; + } + + public ResolvingTest assertNothingReceived() { + return assertReceivedExactly(); + } + + public ResolvingTest assertReceivedExactly(T... values) { + Assertions.assertThat(received).hasSize(values.length).containsExactly(values); + + return this; + } + + public ResolvingTest ifResolvedAssertEqual(T value) { + if (received.size() > 0) { + Assertions.assertThat(received).hasSize(1).containsExactly(value); + } + + return this; + } + + @Override + protected void doOnValueResolved(T value) { + received.offer(value); + } + + @Override + protected void doOnValueExpired(T value) { + expired.offer(value); + } + + @Override + protected void doSubscribe() { + whenSubscribeConsumer.accept(this); + subscribeCalls.incrementAndGet(); + } + + @Override + protected void doOnDispose() { + onDisposeCalls.incrementAndGet(); + } + } + + private static class SinkOneSubscriber implements CoreSubscriber { + + private final Sinks.One processor; + private boolean valueReceived; + + public SinkOneSubscriber(Sinks.One processor) { + this.processor = processor; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(String s) { + valueReceived = true; + processor.tryEmitValue(s); + } + + @Override + public void onError(Throwable t) { + processor.tryEmitError(t); + } + + @Override + public void onComplete() { + if (!valueReceived) { + processor.tryEmitEmpty(); + } + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java new file mode 100755 index 000000000..4f7821e4a --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ResponderOperatorsCommonTest.java @@ -0,0 +1,477 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameType.METADATA_PUSH; +import static io.rsocket.frame.FrameType.NEXT; +import static io.rsocket.frame.FrameType.REQUEST_CHANNEL; +import static io.rsocket.frame.FrameType.REQUEST_FNF; +import static io.rsocket.frame.FrameType.REQUEST_RESPONSE; +import static io.rsocket.frame.FrameType.REQUEST_STREAM; + +import io.netty.buffer.ByteBuf; +import io.rsocket.FrameAssert; +import io.rsocket.Payload; +import io.rsocket.PayloadAssert; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.plugins.TestRequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.test.publisher.TestPublisher; + +public class ResponderOperatorsCommonTest { + + interface Scenario { + FrameType requestType(); + + int maxElements(); + + ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler); + + ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler); + } + + static Stream scenarios() { + return Stream.of( + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_RESPONSE; + } + + @Override + public int maxElements() { + return 1; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber( + streamId, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestResponseResponderSubscriber subscriber = + new RequestResponseResponderSubscriber(streamId, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_RESPONSE, null); + } + + return handler.requestResponse(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestResponseRequesterMono.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_STREAM; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + + streamManager.activeStreams.put(streamId, subscriber); + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestStreamResponderSubscriber subscriber = + new RequestStreamResponderSubscriber(streamId, initialRequestN, streamManager); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_STREAM, null); + } + + return handler.requestStream(firstPayload).subscribeWith(subscriber); + } + + @Override + public String toString() { + return RequestStreamResponderSubscriber.class.getSimpleName(); + } + }, + new Scenario() { + @Override + public FrameType requestType() { + return FrameType.REQUEST_CHANNEL; + } + + @Override + public int maxElements() { + return Integer.MAX_VALUE; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + ByteBuf firstFragment, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber subscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstFragment, streamManager, handler); + streamManager.activeStreams.put(streamId, subscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + + return subscriber; + } + + @Override + public ResponderFrameHandler responseOperator( + long initialRequestN, + Payload firstPayload, + TestRequesterResponderSupport streamManager, + RSocket handler) { + int streamId = streamManager.getNextStreamId(); + RequestChannelResponderSubscriber responderSubscriber = + new RequestChannelResponderSubscriber( + streamId, initialRequestN, firstPayload, streamManager); + streamManager.activeStreams.put(streamId, responderSubscriber); + + final RequestInterceptor requestInterceptor = streamManager.getRequestInterceptor(); + if (requestInterceptor != null) { + requestInterceptor.onStart(streamId, REQUEST_CHANNEL, null); + } + + return handler.requestChannel(responderSubscriber).subscribeWith(responderSubscriber); + } + + @Override + public String toString() { + return RequestChannelResponderSubscriber.class.getSimpleName(); + } + }); + } + + static class TestHandler implements RSocket { + + final TestPublisher producer; + final AssertSubscriber consumer; + + TestHandler(TestPublisher producer, AssertSubscriber consumer) { + this.producer = producer; + this.consumer = consumer; + } + + @Override + public Mono fireAndForget(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono().then(); + } + + @Override + public Mono requestResponse(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.mono(); + } + + @Override + public Flux requestStream(Payload payload) { + consumer.onSubscribe(Operators.emptySubscription()); + consumer.onNext(payload); + consumer.onComplete(); + return producer.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + payloads.subscribe(consumer); + return producer.flux(); + } + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, + TestRequesterResponderSupport.genericPayload(allocator), + testRequesterResponderSupport, + testHandler); + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.REQUEST_N) + .hasStreamId(1) + .hasRequestN(1) + .hasNoLeaks(); + + responderFrameHandler.handleComplete(); + testHandler.consumer.assertComplete(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith(p -> PayloadAssert.assertThat(p).hasNoLeaks()); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleFragmentedRequest(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + final TestDuplexConnection sender = testRequesterResponderSupport.getDuplexConnection(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + responderFrameHandler.handleNext(fragment, hasFollows, !hasFollows); + fragment.release(); + } + + Payload randomPayload = TestRequesterResponderSupport.randomPayload(allocator); + testPublisher.assertWasSubscribed(); + testPublisher.next(randomPayload.retain()); + testPublisher.complete(); + + FrameAssert.assertThat(sender.awaitFrame()) + .isNotNull() + .hasStreamId(1) + .typeOf(scenario.requestType() == REQUEST_RESPONSE ? FrameType.NEXT_COMPLETE : NEXT) + .hasPayloadSize( + randomPayload.data().readableBytes() + randomPayload.sliceMetadata().readableBytes()) + .hasData(randomPayload.data()) + .hasNoLeaks(); + + PayloadAssert.assertThat(randomPayload).hasNoLeaks(); + + if (scenario.requestType() != REQUEST_RESPONSE) { + + FrameAssert.assertThat(sender.awaitFrame()) + .typeOf(FrameType.COMPLETE) + .hasStreamId(1) + .hasNoLeaks(); + + if (scenario.requestType() == REQUEST_CHANNEL) { + testHandler.consumer.request(2); + FrameAssert.assertThat(sender.pollFrame()).isNull(); + } + } + + testHandler + .consumer + .assertValueCount(1) + .assertValuesWith( + p -> PayloadAssert.assertThat(p).hasData(firstPayload.sliceData()).hasNoLeaks()) + .assertComplete(); + + testRequesterResponderSupport.assertNoActiveStreams(); + + firstPayload.release(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnComplete(1) + .expectNothing(); + + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("scenarios") + void shouldHandleInterruptedFragmentation(Scenario scenario) { + Assumptions.assumeThat(scenario.requestType()).isNotIn(REQUEST_FNF, METADATA_PUSH); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + TestRequesterResponderSupport testRequesterResponderSupport = + TestRequesterResponderSupport.client(testRequestInterceptor); + final LeaksTrackingByteBufAllocator allocator = testRequesterResponderSupport.getAllocator(); + TestPublisher testPublisher = TestPublisher.create(); + TestHandler testHandler = new TestHandler(testPublisher, new AssertSubscriber<>(0)); + + int mtu = ThreadLocalRandom.current().nextInt(64, 256); + Payload firstPayload = TestRequesterResponderSupport.randomPayload(allocator); + ArrayList fragments = + TestRequesterResponderSupport.prepareFragments(allocator, mtu, firstPayload); + firstPayload.release(); + + ByteBuf firstFragment = fragments.remove(0); + ResponderFrameHandler responderFrameHandler = + scenario.responseOperator( + Long.MAX_VALUE, firstFragment, testRequesterResponderSupport, testHandler); + firstFragment.release(); + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertHasStream(1, responderFrameHandler); + + for (int i = 0; i < fragments.size(); i++) { + ByteBuf fragment = fragments.get(i); + boolean hasFollows = i != fragments.size() - 1; + if (hasFollows) { + responderFrameHandler.handleNext(fragment, true, false); + } else { + responderFrameHandler.handleCancel(); + } + fragment.release(); + } + + testPublisher.assertWasNotSubscribed(); + testRequesterResponderSupport.assertNoActiveStreams(); + + testRequestInterceptor + .expectOnStart(1, scenario.requestType()) + .expectOnCancel(1) + .expectNothing(); + + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java new file mode 100644 index 000000000..9a51b9419 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/SendUtilsTest.java @@ -0,0 +1,31 @@ +package io.rsocket.core; + +import static org.mockito.Mockito.*; + +import io.netty.util.ReferenceCounted; +import java.util.function.Consumer; +import org.junit.jupiter.api.Test; + +public class SendUtilsTest { + + @Test + void droppedElementsConsumerShouldAcceptOtherTypesThanReferenceCounted() { + Consumer value = extractDroppedElementConsumer(); + value.accept(new Object()); + } + + @Test + void droppedElementsConsumerReleaseReference() { + ReferenceCounted referenceCounted = mock(ReferenceCounted.class); + when(referenceCounted.release()).thenReturn(true); + + Consumer value = extractDroppedElementConsumer(); + value.accept(referenceCounted); + + verify(referenceCounted).release(); + } + + private static Consumer extractDroppedElementConsumer() { + return (Consumer) SendUtils.DISCARD_CONTEXT.stream().findAny().get().getValue(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java new file mode 100644 index 000000000..87c3a865f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/SetupRejectionTest.java @@ -0,0 +1,209 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.transport.ServerTransport.ConnectionAcceptor; +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.exceptions.Exceptions; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.SetupFrameCodec; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.transport.ServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +public class SetupRejectionTest { + + @Test + void responderRejectSetup() { + SingleConnectionTransport transport = new SingleConnectionTransport(); + + String errorMsg = "error"; + RejectingAcceptor acceptor = new RejectingAcceptor(errorMsg); + RSocketServer.create().acceptor(acceptor).bind(transport).block(); + + transport.connect(); + + ByteBuf sentFrame = transport.awaitSent(); + assertThat(FrameHeaderCodec.frameType(sentFrame)).isEqualTo(FrameType.ERROR); + RuntimeException error = Exceptions.from(0, sentFrame); + sentFrame.release(); + assertThat(errorMsg).isEqualTo(error.getMessage()); + assertThat(error).isInstanceOf(RejectedSetupException.class); + RSocket acceptorSender = acceptor.senderRSocket().block(); + assertThat(acceptorSender.isDisposed()).isTrue(); + transport.allocator.assertHasNoLeaks(); + } + + @Test + void requesterStreamsTerminatedOnZeroErrorFrame() { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); + + RSocketRequester rSocket = + new RSocketRequester( + conn, + DefaultPayload::create, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); + + String errorMsg = "error"; + + StepVerifier.create( + rSocket + .requestResponse(DefaultPayload.create("test")) + .doOnRequest( + ignored -> + conn.addToReceivedBuffer( + ErrorFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 0, + new RejectedSetupException(errorMsg))))) + .expectErrorMatches( + err -> err instanceof RejectedSetupException && errorMsg.equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + + assertThat(rSocket.isDisposed()).isTrue(); + allocator.assertHasNoLeaks(); + } + + @Test + void requesterNewStreamsTerminatedAfterZeroErrorFrame() { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + TestDuplexConnection conn = new TestDuplexConnection(allocator); + Sinks.Empty onThisSideClosedSink = Sinks.empty(); + RSocketRequester rSocket = + new RSocketRequester( + conn, + DefaultPayload::create, + StreamIdSupplier.clientSupplier(), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + 0, + 0, + null, + __ -> null, + null, + onThisSideClosedSink, + onThisSideClosedSink.asMono()); + + conn.addToReceivedBuffer( + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 0, new RejectedSetupException("error"))); + + StepVerifier.create( + rSocket + .requestResponse(DefaultPayload.create("test")) + .delaySubscription(Duration.ofMillis(100))) + .expectErrorMatches( + err -> err instanceof RejectedSetupException && "error".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + allocator.assertHasNoLeaks(); + } + + private static class RejectingAcceptor implements SocketAcceptor { + private final String errorMessage; + private final Sinks.Many senderRSockets = + Sinks.many().unicast().onBackpressureBuffer(); + + public RejectingAcceptor(String errorMessage) { + this.errorMessage = errorMessage; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + senderRSockets.tryEmitNext(sendingSocket); + return Mono.error(new RuntimeException(errorMessage)); + } + + public Mono senderRSocket() { + return senderRSockets.asFlux().next(); + } + } + + private static class SingleConnectionTransport implements ServerTransport { + + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + private final TestDuplexConnection conn = new TestDuplexConnection(allocator); + + @Override + public Mono start(ConnectionAcceptor acceptor) { + return Mono.just(new TestCloseable(acceptor, conn)); + } + + public ByteBuf awaitSent() { + return conn.awaitFrame(); + } + + public void connect() { + Payload payload = DefaultPayload.create(DefaultPayload.EMPTY_BUFFER); + ByteBuf setup = SetupFrameCodec.encode(allocator, false, 0, 42, "mdMime", "dMime", payload); + + conn.addToReceivedBuffer(setup); + } + } + + private static class TestCloseable implements Closeable { + + private final DuplexConnection conn; + + TestCloseable(ConnectionAcceptor acceptor, DuplexConnection conn) { + this.conn = conn; + Mono.from(acceptor.apply(conn)).subscribe(notUsed -> {}, err -> conn.dispose()); + } + + @Override + public Mono onClose() { + return conn.onClose(); + } + + @Override + public void dispose() { + conn.dispose(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java new file mode 100644 index 000000000..88e0dc8e2 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldHaveFlag.java @@ -0,0 +1,98 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%s)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_HAVE_FLAG = "Expected state\n\t%s\nto have\n\t%s\nbut had\n\t[%s]"; + + private ShouldHaveFlag(long currentState, String expectedFlag, String actualFlags) { + super(SHOULD_HAVE_FLAG, toBinaryString(currentState), expectedFlag, actualFlags); + } + + static ErrorMessageFactory shouldHaveFlag(long currentState, long expectedFlag) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag(currentState, FLAGS_NAMES.get(expectedFlag), stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestN(long currentState, long expectedRequestN) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + expectedRequestN == Integer.MAX_VALUE ? "MAX" : expectedRequestN), + stateAsString); + } + + static ErrorMessageFactory shouldHaveRequestNBetween( + long currentState, long expectedRequestNMin, long expectedRequestNMax) { + String stateAsString = extractStateAsString(currentState); + return new ShouldHaveFlag( + currentState, + String.format( + FLAGS_NAMES.get(REQUEST_MASK), + (expectedRequestNMin == Integer.MAX_VALUE ? "MAX" : expectedRequestNMin) + + " - " + + (expectedRequestNMax == Integer.MAX_VALUE ? "MAX" : expectedRequestNMax)), + stateAsString); + } + + private static String extractStateAsString(long currentState) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append( + String.format( + FLAGS_NAMES.get(REQUEST_MASK), requestN >= Integer.MAX_VALUE ? "MAX" : requestN)); + } + return stringBuilder.toString(); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java new file mode 100644 index 000000000..e281e548c --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/ShouldNotHaveFlag.java @@ -0,0 +1,73 @@ +package io.rsocket.core; + +import static io.rsocket.core.StateUtils.REQUEST_MASK; +import static io.rsocket.core.StateUtils.SUBSCRIBED_FLAG; +import static io.rsocket.core.StateUtils.extractRequestN; + +import java.util.HashMap; +import java.util.Map; +import org.assertj.core.error.BasicErrorMessageFactory; +import org.assertj.core.error.ErrorMessageFactory; + +class ShouldNotHaveFlag extends BasicErrorMessageFactory { + + static final Map FLAGS_NAMES = + new HashMap() { + { + put(StateUtils.UNSUBSCRIBED_STATE, "UNSUBSCRIBED"); + put(StateUtils.TERMINATED_STATE, "TERMINATED"); + put(SUBSCRIBED_FLAG, "SUBSCRIBED"); + put(StateUtils.REQUEST_MASK, "REQUESTED(%n)"); + put(StateUtils.FIRST_FRAME_SENT_FLAG, "FIRST_FRAME_SENT"); + put(StateUtils.REASSEMBLING_FLAG, "REASSEMBLING"); + put(StateUtils.INBOUND_TERMINATED_FLAG, "INBOUND_TERMINATED"); + put(StateUtils.OUTBOUND_TERMINATED_FLAG, "OUTBOUND_TERMINATED"); + } + }; + + static final String SHOULD_NOT_HAVE_FLAG = + "Expected state\n\t%s\nto not have\n\t%s\nbut had\n\t[%s]"; + + private ShouldNotHaveFlag(long currentState, long expectedFlag, String actualFlags) { + super( + SHOULD_NOT_HAVE_FLAG, + toBinaryString(currentState), + FLAGS_NAMES.get(expectedFlag), + actualFlags); + } + + static ErrorMessageFactory shouldNotHaveFlag(long currentState, long expectedFlag) { + StringBuilder stringBuilder = new StringBuilder(); + long flag = 1L << 31; + for (int i = 0; i < 33; i++, flag <<= 1) { + if ((currentState & flag) == flag) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(FLAGS_NAMES.get(flag)); + } + } + long requestN = extractRequestN(currentState); + if (requestN > 0) { + if (stringBuilder.length() > 0) { + stringBuilder.append(", "); + } + stringBuilder.append(String.format(FLAGS_NAMES.get(REQUEST_MASK), requestN)); + } + return new ShouldNotHaveFlag(currentState, expectedFlag, stringBuilder.toString()); + } + + static String toBinaryString(long state) { + StringBuilder binaryString = new StringBuilder(Long.toBinaryString(state)); + + int diff = 64 - binaryString.length(); + for (int i = 0; i < diff; i++) { + binaryString.insert(0, "0"); + } + + binaryString.insert(33, "_"); + binaryString.insert(0, "0b"); + + return binaryString.toString(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java new file mode 100644 index 000000000..64253984b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/StateAssert.java @@ -0,0 +1,161 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.core.ShouldHaveFlag.*; +import static io.rsocket.core.ShouldNotHaveFlag.shouldNotHaveFlag; +import static io.rsocket.core.StateUtils.*; + +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import org.assertj.core.api.AbstractAssert; +import org.assertj.core.internal.Failures; + +public class StateAssert extends AbstractAssert, AtomicLongFieldUpdater> { + + public static StateAssert assertThat(AtomicLongFieldUpdater updater, T instance) { + return new StateAssert<>(updater, instance); + } + + public static StateAssert assertThat( + FireAndForgetRequesterMono instance) { + return new StateAssert<>(FireAndForgetRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestResponseRequesterMono instance) { + return new StateAssert<>(RequestResponseRequesterMono.STATE, instance); + } + + public static StateAssert assertThat( + RequestStreamRequesterFlux instance) { + return new StateAssert<>(RequestStreamRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelRequesterFlux instance) { + return new StateAssert<>(RequestChannelRequesterFlux.STATE, instance); + } + + public static StateAssert assertThat( + RequestChannelResponderSubscriber instance) { + return new StateAssert<>(RequestChannelResponderSubscriber.STATE, instance); + } + + private final Failures failures = Failures.instance(); + private final T instance; + + public StateAssert(AtomicLongFieldUpdater updater, T instance) { + super(updater, StateAssert.class); + this.instance = instance; + } + + public StateAssert isUnsubscribed() { + long currentState = actual.get(instance); + if (isSubscribed(currentState) || StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, UNSUBSCRIBED_STATE)); + } + return this; + } + + public StateAssert hasSubscribedFlagOnly() { + long currentState = actual.get(instance); + if (currentState != SUBSCRIBED_FLAG) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasSubscribedFlag() { + long currentState = actual.get(instance); + if (!isSubscribed(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, SUBSCRIBED_FLAG)); + } + return this; + } + + public StateAssert hasRequestN(long n) { + long currentState = actual.get(instance); + if (extractRequestN(currentState) != n) { + throw failures.failure(info, shouldHaveRequestN(currentState, n)); + } + return this; + } + + public StateAssert hasRequestNBetween(long min, long max) { + long currentState = actual.get(instance); + final long requestN = extractRequestN(currentState); + if (requestN < min || requestN > max) { + throw failures.failure(info, shouldHaveRequestNBetween(currentState, min, max)); + } + return this; + } + + public StateAssert hasFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (!isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasNoFirstFrameSentFlag() { + long currentState = actual.get(instance); + if (isFirstFrameSent(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, FIRST_FRAME_SENT_FLAG)); + } + return this; + } + + public StateAssert hasReassemblingFlag() { + long currentState = actual.get(instance); + if (!isReassembling(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasNoReassemblingFlag() { + long currentState = actual.get(instance); + if (isReassembling(currentState)) { + throw failures.failure(info, shouldNotHaveFlag(currentState, REASSEMBLING_FLAG)); + } + return this; + } + + public StateAssert hasInboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isInboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, INBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert hasOutboundTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isOutboundTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, OUTBOUND_TERMINATED_FLAG)); + } + return this; + } + + public StateAssert isTerminated() { + long currentState = actual.get(instance); + if (!StateUtils.isTerminated(currentState)) { + throw failures.failure(info, shouldHaveFlag(currentState, TERMINATED_STATE)); + } + return this; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java new file mode 100644 index 000000000..16bd9f16e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/StreamIdSupplierTest.java @@ -0,0 +1,120 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.core; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import org.junit.jupiter.api.Test; + +public class StreamIdSupplierTest { + @Test + public void testClientSequence() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.clientSupplier(); + assertThat(s.nextStreamId(map)).isEqualTo(1); + assertThat(s.nextStreamId(map)).isEqualTo(3); + assertThat(s.nextStreamId(map)).isEqualTo(5); + } + + @Test + public void testServerSequence() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.serverSupplier(); + assertEquals(2, s.nextStreamId(map)); + assertEquals(4, s.nextStreamId(map)); + assertEquals(6, s.nextStreamId(map)); + } + + @Test + public void testClientIsValid() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.clientSupplier(); + + assertFalse(s.isBeforeOrCurrent(1)); + assertFalse(s.isBeforeOrCurrent(3)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(1)); + assertFalse(s.isBeforeOrCurrent(3)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(3)); + + // negative + assertFalse(s.isBeforeOrCurrent(-1)); + // connection + assertFalse(s.isBeforeOrCurrent(0)); + // server also accepted (checked externally) + assertTrue(s.isBeforeOrCurrent(2)); + } + + @Test + public void testServerIsValid() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = StreamIdSupplier.serverSupplier(); + + assertFalse(s.isBeforeOrCurrent(2)); + assertFalse(s.isBeforeOrCurrent(4)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(2)); + assertFalse(s.isBeforeOrCurrent(4)); + + s.nextStreamId(map); + assertTrue(s.isBeforeOrCurrent(4)); + + // negative + assertFalse(s.isBeforeOrCurrent(-2)); + // connection + assertFalse(s.isBeforeOrCurrent(0)); + // client also accepted (checked externally) + assertTrue(s.isBeforeOrCurrent(1)); + } + + @Test + public void testWrap() { + IntObjectMap map = new IntObjectHashMap<>(); + StreamIdSupplier s = new StreamIdSupplier(Integer.MAX_VALUE - 3); + + assertEquals(2147483646, s.nextStreamId(map)); + assertEquals(2, s.nextStreamId(map)); + assertEquals(4, s.nextStreamId(map)); + + s = new StreamIdSupplier(Integer.MAX_VALUE - 2); + + assertEquals(2147483647, s.nextStreamId(map)); + assertEquals(1, s.nextStreamId(map)); + assertEquals(3, s.nextStreamId(map)); + } + + @Test + public void testSkipFound() { + IntObjectMap map = new IntObjectHashMap<>(); + map.put(5, new Object()); + map.put(9, new Object()); + StreamIdSupplier s = StreamIdSupplier.clientSupplier(); + assertEquals(1, s.nextStreamId(map)); + assertEquals(3, s.nextStreamId(map)); + assertEquals(7, s.nextStreamId(map)); + assertEquals(11, s.nextStreamId(map)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java new file mode 100644 index 000000000..e282d72d5 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/core/TestRequesterResponderSupport.java @@ -0,0 +1,281 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.core; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.RequestInterceptor; +import io.rsocket.test.util.TestDuplexConnection; +import io.rsocket.util.ByteBufPayload; +import java.util.ArrayList; +import java.util.concurrent.ThreadLocalRandom; +import org.assertj.core.api.Assertions; +import reactor.core.Exceptions; +import reactor.util.annotation.Nullable; + +final class TestRequesterResponderSupport extends RequesterResponderSupport implements RSocket { + + static final String DATA_CONTENT = "testData"; + static final String METADATA_CONTENT = "testMetadata"; + + final Throwable error; + + TestRequesterResponderSupport( + @Nullable Throwable error, + StreamIdSupplier streamIdSupplier, + DuplexConnection connection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + super( + mtu, + maxFrameLength, + maxInboundPayloadSize, + PayloadDecoder.ZERO_COPY, + connection, + streamIdSupplier, + (__) -> requestInterceptor); + this.error = error; + } + + @Override + public TestDuplexConnection getDuplexConnection() { + return (TestDuplexConnection) super.getDuplexConnection(); + } + + static Payload genericPayload(LeaksTrackingByteBufAllocator allocator) { + ByteBuf data = allocator.buffer(); + data.writeCharSequence(DATA_CONTENT, CharsetUtil.UTF_8); + + ByteBuf metadata = allocator.buffer(); + metadata.writeCharSequence(METADATA_CONTENT, CharsetUtil.UTF_8); + + return ByteBufPayload.create(data, metadata); + } + + static Payload fixedSizePayload(LeaksTrackingByteBufAllocator allocator, int contentSize) { + final int dataSize = ThreadLocalRandom.current().nextInt(0, contentSize); + final byte[] dataBytes = new byte[dataSize]; + ThreadLocalRandom.current().nextBytes(dataBytes); + ByteBuf data = allocator.buffer(dataSize); + data.writeBytes(dataBytes); + + ByteBuf metadata; + int metadataSize = contentSize - dataSize; + if (metadataSize > 0) { + final byte[] metadataBytes = new byte[metadataSize]; + metadata = allocator.buffer(metadataSize); + metadata.writeBytes(metadataBytes); + } else { + metadata = ThreadLocalRandom.current().nextBoolean() ? Unpooled.EMPTY_BUFFER : null; + } + + return ByteBufPayload.create(data, metadata); + } + + static Payload randomPayload(LeaksTrackingByteBufAllocator allocator) { + boolean hasMetadata = ThreadLocalRandom.current().nextBoolean(); + ByteBuf metadataByteBuf; + if (hasMetadata) { + byte[] randomMetadata = new byte[ThreadLocalRandom.current().nextInt(0, 512)]; + ThreadLocalRandom.current().nextBytes(randomMetadata); + metadataByteBuf = allocator.buffer().writeBytes(randomMetadata); + } else { + metadataByteBuf = null; + } + byte[] randomData = new byte[ThreadLocalRandom.current().nextInt(512, 1024)]; + ThreadLocalRandom.current().nextBytes(randomData); + + ByteBuf dataByteBuf = allocator.buffer().writeBytes(randomData); + return ByteBufPayload.create(dataByteBuf, metadataByteBuf); + } + + static Payload randomMetadataOnlyPayload(LeaksTrackingByteBufAllocator allocator) { + byte[] randomMetadata = new byte[ThreadLocalRandom.current().nextInt(512, 1024)]; + ThreadLocalRandom.current().nextBytes(randomMetadata); + ByteBuf metadataByteBuf = allocator.buffer().writeBytes(randomMetadata); + + return ByteBufPayload.create(Unpooled.EMPTY_BUFFER, metadataByteBuf); + } + + static ArrayList prepareFragments( + LeaksTrackingByteBufAllocator allocator, int mtu, Payload payload) { + + return prepareFragments(allocator, mtu, payload, FrameType.NEXT_COMPLETE); + } + + static ArrayList prepareFragments( + LeaksTrackingByteBufAllocator allocator, int mtu, Payload payload, FrameType frameType) { + + boolean hasMetadata = payload.hasMetadata(); + ByteBuf data = payload.sliceData(); + ByteBuf metadata = payload.sliceMetadata(); + ArrayList fragments = new ArrayList<>(); + + fragments.add( + frameType.hasInitialRequestN() + ? FragmentationUtils.encodeFirstFragment( + allocator, mtu, 1L, frameType, 1, hasMetadata, metadata, data) + : FragmentationUtils.encodeFirstFragment( + allocator, mtu, frameType, 1, hasMetadata, metadata, data)); + + while (metadata.isReadable() || data.isReadable()) { + fragments.add( + FragmentationUtils.encodeFollowsFragment(allocator, mtu, 1, true, metadata, data)); + } + + return fragments; + } + + @Override + public synchronized int getNextStreamId() { + int nextStreamId = super.getNextStreamId(); + + if (error != null) { + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + @Override + public synchronized int addAndGetNextStreamId(FrameHandler frameHandler) { + int nextStreamId = super.addAndGetNextStreamId(frameHandler); + + if (error != null) { + super.remove(nextStreamId, frameHandler); + throw Exceptions.propagate(error); + } + + return nextStreamId; + } + + public static TestRequesterResponderSupport client( + @Nullable Throwable e, @Nullable RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor, + e); + } + + public static TestRequesterResponderSupport client(@Nullable Throwable e) { + return client(0, FRAME_LENGTH_MASK, Integer.MAX_VALUE, e); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize, @Nullable Throwable e) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + mtu, + maxFrameLength, + maxInboundPayloadSize, + null, + e); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize) { + return client(duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor) { + return client( + duplexConnection, mtu, maxFrameLength, maxInboundPayloadSize, requestInterceptor, null); + } + + public static TestRequesterResponderSupport client( + TestDuplexConnection duplexConnection, + int mtu, + int maxFrameLength, + int maxInboundPayloadSize, + @Nullable RequestInterceptor requestInterceptor, + @Nullable Throwable e) { + return new TestRequesterResponderSupport( + e, + StreamIdSupplier.clientSupplier(), + duplexConnection, + mtu, + maxFrameLength, + maxInboundPayloadSize, + requestInterceptor); + } + + public static TestRequesterResponderSupport client( + int mtu, int maxFrameLength, int maxInboundPayloadSize) { + return client(mtu, maxFrameLength, maxInboundPayloadSize, null); + } + + public static TestRequesterResponderSupport client(int mtu, int maxFrameLength) { + return client(mtu, maxFrameLength, Integer.MAX_VALUE); + } + + public static TestRequesterResponderSupport client(int mtu) { + return client(mtu, FRAME_LENGTH_MASK); + } + + public static TestRequesterResponderSupport client() { + return client(0); + } + + public static TestRequesterResponderSupport client(RequestInterceptor requestInterceptor) { + return client( + new TestDuplexConnection( + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT)), + 0, + FRAME_LENGTH_MASK, + Integer.MAX_VALUE, + requestInterceptor); + } + + public TestRequesterResponderSupport assertNoActiveStreams() { + Assertions.assertThat(activeStreams).isEmpty(); + return this; + } + + public TestRequesterResponderSupport assertHasStream(int i, FrameHandler stream) { + Assertions.assertThat(activeStreams).containsEntry(i, stream); + return this; + } + + @Override + public LeaksTrackingByteBufAllocator getAllocator() { + return (LeaksTrackingByteBufAllocator) super.getAllocator(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ApplicationErrorExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ApplicationErrorExceptionTest.java new file mode 100644 index 000000000..35b30b951 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ApplicationErrorExceptionTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class ApplicationErrorExceptionTest + implements RSocketExceptionTest { + + @Override + public ApplicationErrorException getException(String message) { + return new ApplicationErrorException(message); + } + + @Override + public ApplicationErrorException getException(String message, Throwable cause) { + return new ApplicationErrorException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000201; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/CanceledExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/CanceledExceptionTest.java new file mode 100644 index 000000000..6df9e6a4d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/CanceledExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class CanceledExceptionTest implements RSocketExceptionTest { + + @Override + public CanceledException getException(String message) { + return new CanceledException(message); + } + + @Override + public CanceledException getException(String message, Throwable cause) { + return new CanceledException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000203; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionCloseExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionCloseExceptionTest.java new file mode 100644 index 000000000..fe98b55de --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionCloseExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class ConnectionCloseExceptionTest implements RSocketExceptionTest { + + @Override + public ConnectionCloseException getException(String message) { + return new ConnectionCloseException(message); + } + + @Override + public ConnectionCloseException getException(String message, Throwable cause) { + return new ConnectionCloseException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000102; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionErrorExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionErrorExceptionTest.java new file mode 100644 index 000000000..a2bd45a38 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ConnectionErrorExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class ConnectionErrorExceptionTest implements RSocketExceptionTest { + + @Override + public ConnectionErrorException getException(String message) { + return new ConnectionErrorException(message); + } + + @Override + public ConnectionErrorException getException(String message, Throwable cause) { + return new ConnectionErrorException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000101; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java new file mode 100644 index 000000000..a316aed8b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/ExceptionsTest.java @@ -0,0 +1,283 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import static io.rsocket.frame.ErrorFrameCodec.APPLICATION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.CANCELED; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_CLOSE; +import static io.rsocket.frame.ErrorFrameCodec.CONNECTION_ERROR; +import static io.rsocket.frame.ErrorFrameCodec.INVALID; +import static io.rsocket.frame.ErrorFrameCodec.INVALID_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_RESUME; +import static io.rsocket.frame.ErrorFrameCodec.REJECTED_SETUP; +import static io.rsocket.frame.ErrorFrameCodec.UNSUPPORTED_SETUP; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.rsocket.RaceTestConstants; +import io.rsocket.frame.ErrorFrameCodec; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class ExceptionsTest { + @DisplayName("from returns ApplicationErrorException") + @Test + void fromApplicationException() { + ByteBuf byteBuf = createErrorFrame(1, APPLICATION_ERROR, "test-message"); + + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(ApplicationErrorException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", APPLICATION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns CanceledException") + @Test + void fromCanceledException() { + ByteBuf byteBuf = createErrorFrame(1, CANCELED, "test-message"); + + try { + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CanceledException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", CANCELED, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns ConnectionCloseException") + @Test + void fromConnectionCloseException() { + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_CLOSE, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionCloseException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_CLOSE, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns ConnectionErrorException") + @Test + void fromConnectionErrorException() { + ByteBuf byteBuf = createErrorFrame(0, CONNECTION_ERROR, "test-message"); + + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(ConnectionErrorException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", CONNECTION_ERROR, "test-message"); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns IllegalArgumentException if error frame has illegal error code") + @Test + void fromIllegalErrorFrame() { + ByteBuf byteBuf = createErrorFrame(0, 0x00000000, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", 0, "test-message") + .isInstanceOf(IllegalArgumentException.class); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 1: 0x%08X '%s'", 0x00000000, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns InvalidException") + @Test + void fromInvalidException() { + ByteBuf byteBuf = createErrorFrame(1, INVALID, "test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(InvalidException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", INVALID, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns InvalidSetupException") + @Test + void fromInvalidSetupException() { + ByteBuf byteBuf = createErrorFrame(0, INVALID_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(InvalidSetupException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", INVALID_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns RejectedException") + @Test + void fromRejectedException() { + ByteBuf byteBuf = createErrorFrame(1, REJECTED, "test-message"); + try { + + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(RejectedException.class) + .withFailMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage("Invalid Error frame in Stream ID 0: 0x%08X '%s'", REJECTED, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns RejectedResumeException") + @Test + void fromRejectedResumeException() { + ByteBuf byteBuf = createErrorFrame(0, REJECTED_RESUME, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedResumeException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_RESUME, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns RejectedSetupException") + @Test + void fromRejectedSetupException() { + ByteBuf byteBuf = createErrorFrame(0, REJECTED_SETUP, "test-message"); + try { + + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(RejectedSetupException.class) + .withFailMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", REJECTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns UnsupportedSetupException") + @Test + void fromUnsupportedSetupException() { + ByteBuf byteBuf = createErrorFrame(0, UNSUPPORTED_SETUP, "test-message"); + try { + assertThat(Exceptions.from(0, byteBuf)) + .isInstanceOf(UnsupportedSetupException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(1, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 1: 0x%08X '%s'", UNSUPPORTED_SETUP, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + + @DisplayName("from returns CustomRSocketException") + @Test + void fromCustomRSocketException() { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + int randomCode = + ThreadLocalRandom.current().nextBoolean() + ? ThreadLocalRandom.current() + .nextInt(Integer.MIN_VALUE, ErrorFrameCodec.MAX_USER_ALLOWED_ERROR_CODE) + : ThreadLocalRandom.current() + .nextInt(ErrorFrameCodec.MIN_USER_ALLOWED_ERROR_CODE, Integer.MAX_VALUE); + ByteBuf byteBuf = createErrorFrame(0, randomCode, "test-message"); + try { + assertThat(Exceptions.from(1, byteBuf)) + .isInstanceOf(CustomRSocketException.class) + .hasMessage("test-message"); + + assertThat(Exceptions.from(0, byteBuf)) + .hasMessage( + "Invalid Error frame in Stream ID 0: 0x%08X '%s'", randomCode, "test-message") + .isInstanceOf(IllegalArgumentException.class); + } finally { + byteBuf.release(); + } + } + } + + @DisplayName("from throws NullPointerException with null frame") + @Test + void fromWithNullFrame() { + assertThatNullPointerException() + .isThrownBy(() -> Exceptions.from(0, null)) + .withMessage("frame must not be null"); + } + + private ByteBuf createErrorFrame(int streamId, int errorCode, String message) { + return ErrorFrameCodec.encode( + UnpooledByteBufAllocator.DEFAULT, streamId, new TestRSocketException(errorCode, message)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidExceptionTest.java new file mode 100644 index 000000000..a7dec62b4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class InvalidExceptionTest implements RSocketExceptionTest { + + @Override + public InvalidException getException(String message) { + return new InvalidException(message); + } + + @Override + public InvalidException getException(String message, Throwable cause) { + return new InvalidException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000204; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidSetupExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidSetupExceptionTest.java new file mode 100644 index 000000000..d7fce8cc8 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/InvalidSetupExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class InvalidSetupExceptionTest implements RSocketExceptionTest { + + @Override + public InvalidSetupException getException(String message) { + return new InvalidSetupException(message); + } + + @Override + public InvalidSetupException getException(String message, Throwable cause) { + return new InvalidSetupException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000001; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java new file mode 100644 index 000000000..9aa8fc364 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RSocketExceptionTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.RSocketErrorException; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +interface RSocketExceptionTest { + + @DisplayName("constructor does not throw NullPointerException with null message") + @Test + default void constructorWithNullMessage() { + assertThat(getException(null)).hasMessage(null); + } + + @DisplayName("constructor does not throw NullPointerException with null message and cause") + @Test + default void constructorWithNullMessageAndCause() { + assertThat(getException(null)).hasMessage(null); + } + + @DisplayName("errorCode returns specified value") + @Test + default void errorCodeReturnsSpecifiedValue() { + assertThat(getException("test-message").errorCode()).isEqualTo(getSpecifiedErrorCode()); + } + + T getException(String message, Throwable cause); + + T getException(String message); + + int getSpecifiedErrorCode(); +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedExceptionTest.java new file mode 100644 index 000000000..209595596 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class RejectedExceptionTest implements RSocketExceptionTest { + + @Override + public RejectedException getException(String message) { + return new RejectedException(message); + } + + @Override + public RejectedException getException(String message, Throwable cause) { + return new RejectedException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000202; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedResumeExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedResumeExceptionTest.java new file mode 100644 index 000000000..555ff160d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedResumeExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class RejectedResumeExceptionTest implements RSocketExceptionTest { + + @Override + public RejectedResumeException getException(String message) { + return new RejectedResumeException(message); + } + + @Override + public RejectedResumeException getException(String message, Throwable cause) { + return new RejectedResumeException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000004; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedSetupExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedSetupExceptionTest.java new file mode 100644 index 000000000..2fe63c09d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/RejectedSetupExceptionTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class RejectedSetupExceptionTest implements RSocketExceptionTest { + + @Override + public RejectedSetupException getException(String message) { + return new RejectedSetupException(message); + } + + @Override + public RejectedSetupException getException(String message, Throwable cause) { + return new RejectedSetupException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000003; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java new file mode 100644 index 000000000..15685aa43 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/TestRSocketException.java @@ -0,0 +1,42 @@ +package io.rsocket.exceptions; + +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; + +public class TestRSocketException extends RSocketErrorException { + private static final long serialVersionUID = 7873267740343446585L; + + private final int errorCode; + + /** + * Constructs a new exception with the specified message. + * + * @param errorCode customizable error code + * @param message the message + * @throws NullPointerException if {@code message} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message) { + super(ErrorFrameCodec.APPLICATION_ERROR, message); + this.errorCode = errorCode; + } + + /** + * Constructs a new exception with the specified message and cause. + * + * @param errorCode customizable error code + * @param message the message + * @param cause the cause of this exception + * @throws NullPointerException if {@code message} or {@code cause} is {@code null} + * @throws IllegalArgumentException if {@code errorCode} is out of allowed range + */ + public TestRSocketException(int errorCode, String message, Throwable cause) { + super(ErrorFrameCodec.APPLICATION_ERROR, message, cause); + this.errorCode = errorCode; + } + + @Override + public int errorCode() { + return errorCode; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/exceptions/UnsupportedSetupExceptionTest.java b/rsocket-core/src/test/java/io/rsocket/exceptions/UnsupportedSetupExceptionTest.java new file mode 100644 index 000000000..6c73ff564 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/exceptions/UnsupportedSetupExceptionTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.exceptions; + +final class UnsupportedSetupExceptionTest + implements RSocketExceptionTest { + + @Override + public UnsupportedSetupException getException(String message) { + return new UnsupportedSetupException(message); + } + + @Override + public UnsupportedSetupException getException(String message, Throwable cause) { + return new UnsupportedSetupException(message, cause); + } + + @Override + public int getSpecifiedErrorCode() { + return 0x00000002; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java new file mode 100644 index 000000000..b12d72b51 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ByteBufRepresentation.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.api.Assertions; +import org.assertj.core.presentation.StandardRepresentation; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; + +public final class ByteBufRepresentation extends StandardRepresentation + implements BeforeAllCallback { + + @Override + public void beforeAll(ExtensionContext context) { + Assertions.useRepresentation(this); + } + + @Override + protected String fallbackToStringOf(Object object) { + if (object instanceof ByteBuf) { + try { + String normalBufferString = object.toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 128) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } + } catch (IllegalReferenceCountException e) { + // noops + } + } + + return super.fallbackToStringOf(object); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java new file mode 100644 index 000000000..dc04c1141 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ErrorFrameCodecTest.java @@ -0,0 +1,21 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.exceptions.ApplicationErrorException; +import org.junit.jupiter.api.Test; + +class ErrorFrameCodecTest { + @Test + void testEncode() { + ByteBuf frame = + ErrorFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, new ApplicationErrorException("d")); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + assertEquals("00000b000000012c000000020164", ByteBufUtil.hexDump(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java new file mode 100644 index 000000000..28209393e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ExtensionFrameCodecTest.java @@ -0,0 +1,62 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ExtensionFrameCodecTest { + + @Test + void extensionDataMetadata() { + ByteBuf metadata = bytebuf("md"); + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, metadata, data); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionData() { + ByteBuf data = bytebuf("d"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, extendedType, null, data); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertNull(ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(data, ExtensionFrameCodec.data(extension)); + extension.release(); + } + + @Test + void extensionMetadata() { + ByteBuf metadata = bytebuf("md"); + int extendedType = 1; + + ByteBuf extension = + ExtensionFrameCodec.encode( + ByteBufAllocator.DEFAULT, 1, extendedType, metadata, Unpooled.EMPTY_BUFFER); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(extension)); + Assertions.assertEquals(extendedType, ExtensionFrameCodec.extendedType(extension)); + Assertions.assertEquals(metadata, ExtensionFrameCodec.metadata(extension)); + Assertions.assertEquals(0, ExtensionFrameCodec.data(extension).readableBytes()); + extension.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java new file mode 100644 index 000000000..15788e631 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/FrameHeaderCodecTest.java @@ -0,0 +1,36 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.junit.jupiter.api.Test; + +class FrameHeaderCodecTest { + // Taken from spec + private static final int FRAME_MAX_SIZE = 16_777_215; + + @Test + void typeAndFlag() { + FrameType frameType = FrameType.REQUEST_FNF; + int flags = 0b1110110111; + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + + assertEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); + header.release(); + } + + @Test + void typeAndFlagTruncated() { + FrameType frameType = FrameType.SETUP; + int flags = 0b11110110111; // 1 bit too many + ByteBuf header = FrameHeaderCodec.encode(ByteBufAllocator.DEFAULT, 0, frameType, flags); + + assertNotEquals(flags, FrameHeaderCodec.flags(header)); + assertEquals(flags & 0b0000_0011_1111_1111, FrameHeaderCodec.flags(header)); + assertEquals(frameType, FrameHeaderCodec.frameType(header)); + header.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java new file mode 100644 index 000000000..ac19dc754 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/GenericFrameCodecTest.java @@ -0,0 +1,264 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; + +class GenericFrameCodecTest { + @Test + void testEncoding() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length + // | | | | ⌌Encoded Metadata + // | | | | | ⌌Encoded Data + // __|________|_________|______|____|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "000010000000011900000000010000026d6464"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithEmptyMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + Unpooled.EMPTY_BUFFER, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Metadata Length (0) + // | | | | ⌌Encoded Data + // __|________|_________|_______|___| + // ↓ ↓↓ ↓↓ ↓↓ ↓↓↓ + String expected = "00000e0000000119000000000100000064"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void testEncodingWithNullMetadata() { + ByteBuf frame = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 1, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + + // Encoded FrameLength⌍ ⌌ Encoded Headers + // | | ⌌ Encoded Request(1) + // | | | ⌌Encoded Data + // __|________|_________|_____| + // ↓<-> ↓↓ <-> ↓↓ <-> ↓↓↓ + String expected = "00000b0000000118000000000164"; + assertEquals(expected, ByteBufUtil.hexDump(frame)); + frame.release(); + } + + @Test + void requestResponseDataMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestResponseData() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestResponseFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestResponseFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestResponseMetadata() { + ByteBuf request = + RequestResponseFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestResponseFrameCodec.data(request); + String metadata = RequestResponseFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestStreamDataMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Integer.MAX_VALUE + 1L, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(Long.MAX_VALUE, actualRequest); + assertEquals("md", metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamData() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + String data = RequestStreamFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestStreamFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertNull(metadata); + assertEquals("d", data); + request.release(); + } + + @Test + void requestStreamMetadata() { + ByteBuf request = + RequestStreamFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + 42, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + long actualRequest = RequestStreamFrameCodec.initialRequestN(request); + ByteBuf data = RequestStreamFrameCodec.data(request); + String metadata = RequestStreamFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals(42L, actualRequest); + assertTrue(data.readableBytes() == 0); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfDataAndMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertEquals("md", metadata); + request.release(); + } + + @Test + void requestFnfData() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + null, + Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + + String data = RequestFireAndForgetFrameCodec.data(request).toString(StandardCharsets.UTF_8); + ByteBuf metadata = RequestFireAndForgetFrameCodec.metadata(request); + + assertFalse(FrameHeaderCodec.hasMetadata(request)); + assertEquals("d", data); + assertNull(metadata); + request.release(); + } + + @Test + void requestFnfMetadata() { + ByteBuf request = + RequestFireAndForgetFrameCodec.encode( + ByteBufAllocator.DEFAULT, + 1, + false, + Unpooled.copiedBuffer("md", StandardCharsets.UTF_8), + Unpooled.EMPTY_BUFFER); + + ByteBuf data = RequestFireAndForgetFrameCodec.data(request); + String metadata = + RequestFireAndForgetFrameCodec.metadata(request).toString(StandardCharsets.UTF_8); + + assertTrue(FrameHeaderCodec.hasMetadata(request)); + assertEquals("md", metadata); + assertTrue(data.readableBytes() == 0); + request.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java new file mode 100644 index 000000000..bc013e024 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/KeepaliveFrameFlyweightTest.java @@ -0,0 +1,32 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Test; + +class KeepaliveFrameFlyweightTest { + @Test + void canReadData() { + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + ByteBuf frame = KeepAliveFrameCodec.encode(ByteBufAllocator.DEFAULT, true, 0, data); + assertTrue(KeepAliveFrameCodec.respondFlag(frame)); + assertEquals(data, KeepAliveFrameCodec.data(frame)); + frame.release(); + } + + @Test + void testEncoding() { + ByteBuf frame = + KeepAliveFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 0, Unpooled.copiedBuffer("d", StandardCharsets.UTF_8)); + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + assertEquals("00000f000000000c80000000000000000064", ByteBufUtil.hexDump(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java new file mode 100644 index 000000000..73c3bde5e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/LeaseFrameCodecTest.java @@ -0,0 +1,42 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class LeaseFrameCodecTest { + + @Test + void leaseMetadata() { + ByteBuf metadata = bytebuf("md"); + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, metadata); + + Assertions.assertTrue(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertEquals(metadata, LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + @Test + void leaseAbsentMetadata() { + int ttl = 1; + int numRequests = 42; + ByteBuf lease = LeaseFrameCodec.encode(ByteBufAllocator.DEFAULT, ttl, numRequests, null); + + Assertions.assertFalse(FrameHeaderCodec.hasMetadata(lease)); + Assertions.assertEquals(ttl, LeaseFrameCodec.ttl(lease)); + Assertions.assertEquals(numRequests, LeaseFrameCodec.numRequests(lease)); + Assertions.assertNull(LeaseFrameCodec.metadata(lease)); + lease.release(); + } + + private static ByteBuf bytebuf(String str) { + return Unpooled.copiedBuffer(str, StandardCharsets.UTF_8); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java new file mode 100644 index 000000000..aecbb31ce --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/PayloadFlyweightTest.java @@ -0,0 +1,88 @@ +package io.rsocket.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.StandardCharsets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class PayloadFlyweightTest { + + @Test + void nextCompleteDataMetadata() { + Payload payload = DefaultPayload.create("d", "md"); + ByteBuf nextComplete = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); + Assertions.assertEquals("d", data); + Assertions.assertEquals("md", metadata); + nextComplete.release(); + } + + @Test + void nextCompleteData() { + Payload payload = DefaultPayload.create("d"); + ByteBuf nextComplete = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(nextComplete).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(nextComplete); + Assertions.assertEquals("d", data); + Assertions.assertNull(metadata); + nextComplete.release(); + } + + @Test + void nextCompleteMetaData() { + Payload payload = + DefaultPayload.create( + Unpooled.EMPTY_BUFFER, Unpooled.wrappedBuffer("md".getBytes(StandardCharsets.UTF_8))); + + ByteBuf nextComplete = + PayloadFrameCodec.encodeNextCompleteReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + ByteBuf data = PayloadFrameCodec.data(nextComplete); + String metadata = PayloadFrameCodec.metadata(nextComplete).toString(StandardCharsets.UTF_8); + Assertions.assertTrue(data.readableBytes() == 0); + Assertions.assertEquals("md", metadata); + nextComplete.release(); + } + + @Test + void nextDataMetadata() { + Payload payload = DefaultPayload.create("d", "md"); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + String metadata = PayloadFrameCodec.metadata(next).toString(StandardCharsets.UTF_8); + Assertions.assertEquals("d", data); + Assertions.assertEquals("md", metadata); + next.release(); + } + + @Test + void nextData() { + Payload payload = DefaultPayload.create("d"); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); + Assertions.assertEquals("d", data); + Assertions.assertNull(metadata); + next.release(); + } + + @Test + void nextDataEmptyMetadata() { + Payload payload = DefaultPayload.create("d".getBytes(), new byte[0]); + ByteBuf next = + PayloadFrameCodec.encodeNextReleasingPayload(ByteBufAllocator.DEFAULT, 1, payload); + String data = PayloadFrameCodec.data(next).toString(StandardCharsets.UTF_8); + ByteBuf metadata = PayloadFrameCodec.metadata(next); + Assertions.assertEquals("d", data); + Assertions.assertEquals(metadata.readableBytes(), 0); + next.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java new file mode 100644 index 000000000..e38258040 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/RequestNFrameCodecTest.java @@ -0,0 +1,19 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import org.junit.jupiter.api.Test; + +class RequestNFrameCodecTest { + @Test + void testEncoding() { + ByteBuf frame = RequestNFrameCodec.encode(ByteBufAllocator.DEFAULT, 1, 5); + + frame = FrameLengthCodec.encode(ByteBufAllocator.DEFAULT, frame.readableBytes(), frame); + assertEquals("00000a00000001200000000005", ByteBufUtil.hexDump(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java new file mode 100644 index 000000000..4815bfb8e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeFrameCodecTest.java @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +public class ResumeFrameCodecTest { + + @Test + void testEncoding() { + byte[] tokenBytes = new byte[65000]; + Arrays.fill(tokenBytes, (byte) 1); + ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); + ByteBuf byteBuf = ResumeFrameCodec.encode(ByteBufAllocator.DEFAULT, token, 21, 12); + assertThat(ResumeFrameCodec.version(byteBuf)).isEqualTo(ResumeFrameCodec.CURRENT_VERSION); + assertThat(ResumeFrameCodec.token(byteBuf)).isEqualTo(token); + assertThat(ResumeFrameCodec.lastReceivedServerPos(byteBuf)).isEqualTo(21); + assertThat(ResumeFrameCodec.firstAvailableClientPos(byteBuf)).isEqualTo(12); + byteBuf.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java new file mode 100644 index 000000000..b818d579d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/ResumeOkFrameCodecTest.java @@ -0,0 +1,17 @@ +package io.rsocket.frame; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.junit.jupiter.api.Test; + +public class ResumeOkFrameCodecTest { + + @Test + public void testEncoding() { + ByteBuf byteBuf = ResumeOkFrameCodec.encode(ByteBufAllocator.DEFAULT, 42); + assertThat(ResumeOkFrameCodec.lastReceivedClientPos(byteBuf)).isEqualTo(42); + byteBuf.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java new file mode 100644 index 000000000..3317b4618 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/SetupFrameCodecTest.java @@ -0,0 +1,57 @@ +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.util.DefaultPayload; +import java.util.Arrays; +import org.junit.jupiter.api.Test; + +class SetupFrameCodecTest { + @Test + void testEncodingNoResume() { + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, false, 5, 500, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertFalse(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(0, SetupFrameCodec.resumeToken(frame).readableBytes()); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(payload.metadata(), SetupFrameCodec.metadata(frame)); + assertEquals(payload.data(), SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } + + @Test + void testEncodingResume() { + byte[] tokenBytes = new byte[65000]; + Arrays.fill(tokenBytes, (byte) 1); + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + ByteBuf data = Unpooled.wrappedBuffer(new byte[] {5, 4, 3}); + Payload payload = DefaultPayload.create(data, metadata); + ByteBuf token = Unpooled.wrappedBuffer(tokenBytes); + ByteBuf frame = + SetupFrameCodec.encode( + ByteBufAllocator.DEFAULT, true, 5, 500, token, "metadata_type", "data_type", payload); + + assertEquals(FrameType.SETUP, FrameHeaderCodec.frameType(frame)); + assertTrue(SetupFrameCodec.honorLease(frame)); + assertTrue(SetupFrameCodec.resumeEnabled(frame)); + assertEquals(token, SetupFrameCodec.resumeToken(frame)); + assertEquals("metadata_type", SetupFrameCodec.metadataMimeType(frame)); + assertEquals("data_type", SetupFrameCodec.dataMimeType(frame)); + assertEquals(payload.metadata(), SetupFrameCodec.metadata(frame)); + assertEquals(payload.data(), SetupFrameCodec.data(frame)); + assertEquals(SetupFrameCodec.CURRENT_VERSION, SetupFrameCodec.version(frame)); + frame.release(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java new file mode 100644 index 000000000..be7fb837b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/VersionCodecTest.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; + +public class VersionCodecTest { + @Test + public void simple() { + int version = VersionCodec.encode(1, 0); + assertEquals(1, VersionCodec.major(version)); + assertEquals(0, VersionCodec.minor(version)); + assertEquals(0x00010000, version); + assertEquals("1.0", VersionCodec.toString(version)); + } + + @Test + public void complex() { + int version = VersionCodec.encode(0x1234, 0x5678); + assertEquals(0x1234, VersionCodec.major(version)); + assertEquals(0x5678, VersionCodec.minor(version)); + assertEquals(0x12345678, version); + assertEquals("4660.22136", VersionCodec.toString(version)); + } + + @Test + public void noShortOverflow() { + int version = VersionCodec.encode(43210, 43211); + assertEquals(43210, VersionCodec.major(version)); + assertEquals(43211, VersionCodec.minor(version)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/frame/old/LeaseFrameFlyweightTest.java b/rsocket-core/src/test/java/io/rsocket/frame/old/LeaseFrameFlyweightTest.java new file mode 100644 index 000000000..ef4fcc6b0 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/frame/old/LeaseFrameFlyweightTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.frame.old; + +public class LeaseFrameFlyweightTest { + /*private final ByteBuf byteBuf = Unpooled.buffer(1024); + + @Test + public void size() { + ByteBuf metadata = Unpooled.wrappedBuffer(new byte[] {1, 2, 3, 4}); + int length = LeaseFrameFlyweight.encode(byteBuf, 0, 0, metadata); + assertEquals(length, 9 + 4 * 2 + 4); // Frame header + ttl + #requests + 4 byte metadata + } + + @Test + public void testEncoding() { + int encoded = + LeaseFrameFlyweight.encode( + byteBuf, 0, 0, Unpooled.copiedBuffer("md", StandardCharsets.UTF_8)); + assertEquals( + "00001000000000090000000000000000006d64", ByteBufUtil.hexDump(byteBuf, 0, encoded)); + }*/ +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java new file mode 100644 index 000000000..d73f92b85 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/SchedulerUtils.java @@ -0,0 +1,23 @@ +package io.rsocket.internal; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import reactor.core.scheduler.Scheduler; + +public class SchedulerUtils { + + public static void warmup(Scheduler scheduler) throws InterruptedException { + warmup(scheduler, 10000); + } + + public static void warmup(Scheduler scheduler, int times) throws InterruptedException { + scheduler.start(); + + // warm up + CountDownLatch latch = new CountDownLatch(times); + for (int i = 0; i < times; i++) { + scheduler.schedule(latch::countDown); + } + latch.await(5, TimeUnit.SECONDS); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java new file mode 100644 index 000000000..343a93beb --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/UnboundedProcessorTest.java @@ -0,0 +1,366 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.RaceTestConstants; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.time.Duration; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Fuseable; +import reactor.core.publisher.Hooks; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.test.util.RaceTestUtils; + +public class UnboundedProcessorTest { + + @BeforeAll + public static void setup() { + Hooks.onErrorDropped(__ -> {}); + } + + @AfterAll + public static void teardown() { + Hooks.resetOnErrorDropped(); + } + + @ParameterizedTest( + name = + "Test that emitting {0} onNext before subscribe and requestN should deliver all the signals once the subscriber is available") + @ValueSource(ints = {10, 100, 10_000, 100_000, 1_000_000, 10_000_000}) + public void testOnNextBeforeSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); + + for (int i = 0; i < n; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + processor.onComplete(); + + StepVerifier.create(processor.count()).expectNext(Long.valueOf(n)).verifyComplete(); + } + + @ParameterizedTest( + name = + "Test that emitting {0} onNext after subscribe and requestN should deliver all the signals") + @ValueSource(ints = {10, 100, 10_000}) + public void testOnNextAfterSubscribeN(int n) { + UnboundedProcessor processor = new UnboundedProcessor(); + AssertSubscriber assertSubscriber = AssertSubscriber.create(); + + processor.subscribe(assertSubscriber); + + for (int i = 0; i < n; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + assertSubscriber.awaitAndAssertNextValueCount(n); + } + + @ParameterizedTest( + name = + "Test that prioritized value sending deliver prioritized signals before the others mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void testPrioritizedSending(boolean fusedCase) { + UnboundedProcessor processor = new UnboundedProcessor(); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + processor.onNext(Unpooled.EMPTY_BUFFER); + } + + processor.onNextPrioritized(Unpooled.copiedBuffer("test", CharsetUtil.UTF_8)); + + assertThat(fusedCase ? processor.poll() : processor.next().block()) + .isNotNull() + .extracting(bb -> bb.toString(CharsetUtil.UTF_8)) + .isEqualTo("test"); + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensureUnboundedProcessorDisposesQueueProperly(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + }, + unboundedProcessor::dispose, + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | cancel | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest1(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + assertSubscriber::cancel, + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe | request(n) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest2(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> { + unboundedProcessor.subscribe(assertSubscriber); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest3(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + assertSubscriber.cancel(); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber)); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext | dispose | subscribe(cancelled) | terminal will not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void smokeTest31(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + final RuntimeException runtimeException = new RuntimeException("test"); + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber(0) + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + RaceTestUtils.race( + Schedulers.boundedElastic(), + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNextPrioritized(buffer2); + }, + () -> { + unboundedProcessor.onNextPrioritized(buffer3); + unboundedProcessor.onNext(buffer4); + }, + unboundedProcessor::dispose, + unboundedProcessor::onComplete, + () -> unboundedProcessor.onError(runtimeException), + () -> unboundedProcessor.subscribe(assertSubscriber), + () -> { + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + assertSubscriber.request(1); + }, + assertSubscriber::cancel); + + assertSubscriber.values().forEach(ReferenceCountUtil::release); + allocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest( + name = + "Ensures that racing between onNext + dispose | downstream async drain should not cause any issues and leaks; mode[fusionEnabled={0}]") + @ValueSource(booleans = {true, false}) + public void ensuresAsyncFusionAndDisposureHasNoDeadlock(boolean withFusionEnabled) { + final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final UnboundedProcessor unboundedProcessor = new UnboundedProcessor(); + final ByteBuf buffer1 = allocator.buffer(1); + final ByteBuf buffer2 = allocator.buffer(2); + final ByteBuf buffer3 = allocator.buffer(3); + final ByteBuf buffer4 = allocator.buffer(4); + final ByteBuf buffer5 = allocator.buffer(5); + final ByteBuf buffer6 = allocator.buffer(6); + + final AssertSubscriber assertSubscriber = + new AssertSubscriber() + .requestedFusionMode(withFusionEnabled ? Fuseable.ANY : Fuseable.NONE); + + unboundedProcessor.subscribe(assertSubscriber); + + RaceTestUtils.race( + () -> { + unboundedProcessor.onNext(buffer1); + unboundedProcessor.onNext(buffer2); + unboundedProcessor.onNext(buffer3); + unboundedProcessor.onNext(buffer4); + unboundedProcessor.onNext(buffer5); + unboundedProcessor.onNext(buffer6); + unboundedProcessor.dispose(); + }, + unboundedProcessor::dispose); + + assertSubscriber.await(Duration.ofSeconds(50)).values().forEach(ReferenceCountUtil::release); + } + + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java new file mode 100644 index 000000000..b6eac9835 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/internal/subscriber/AssertSubscriber.java @@ -0,0 +1,1277 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.internal.subscriber; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import java.util.function.BooleanSupplier; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; +import reactor.util.context.Context; + +/** + * A Subscriber implementation that hosts assertion tests for its state and allows asynchronous + * cancellation and requesting. + * + *

To create a new instance of {@link AssertSubscriber}, you have the choice between these static + * methods: + * + *

    + *
  • {@link AssertSubscriber#create()}: create a new {@link AssertSubscriber} and requests an + * unbounded number of elements. + *
  • {@link AssertSubscriber#create(long)}: create a new {@link AssertSubscriber} and requests + * {@code n} elements (can be 0 if you want no initial demand). + *
+ * + *

If you are testing asynchronous publishers, don't forget to use one of the {@code await*()} + * methods to wait for the data to assert. + * + *

You can extend this class but only the onNext, onError and onComplete can be overridden. You + * can call {@link #request(long)} and {@link #cancel()} from any thread or from within the + * overridable methods but you should avoid calling the assertXXX methods asynchronously. + * + *

Usage: + * + *

{@code
+ * AssertSubscriber
+ *   .subscribe(publisher)
+ *   .await()
+ *   .assertValues("ABC", "DEF");
+ * }
+ * + * @param the value type. + * @author Sebastien Deleuze + * @author David Karnok + * @author Anatoly Kadyshev + * @author Stephane Maldini + * @author Brian Clozel + */ +public class AssertSubscriber implements CoreSubscriber, Subscription, Scannable { + + /** Default timeout for waiting next values to be received */ + public static final Duration DEFAULT_VALUES_TIMEOUT = Duration.ofSeconds(3); + + @SuppressWarnings("rawtypes") + private static final AtomicLongFieldUpdater REQUESTED = + AtomicLongFieldUpdater.newUpdater(AssertSubscriber.class, "requested"); + + @SuppressWarnings("rawtypes") + private static final AtomicIntegerFieldUpdater WIP = + AtomicIntegerFieldUpdater.newUpdater(AssertSubscriber.class, "wip"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater NEXT_VALUES = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, List.class, "values"); + + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater S = + AtomicReferenceFieldUpdater.newUpdater(AssertSubscriber.class, Subscription.class, "s"); + + private final Context context; + + private final List errors = new LinkedList<>(); + + private final CountDownLatch cdl = new CountDownLatch(1); + + volatile boolean done; + + volatile Subscription s; + + volatile long requested; + + volatile int wip; + + volatile List values = new LinkedList<>(); + + /** The fusion mode to request. */ + private int requestedFusionMode = -1; + + /** The established fusion mode. */ + private volatile int establishedFusionMode = -1; + + /** The fuseable QueueSubscription in case a fusion mode was specified. */ + private Fuseable.QueueSubscription qs; + + private int subscriptionCount = 0; + + private int completionCount = 0; + + private volatile long valueCount = 0L; + + private volatile long nextValueAssertedCount = 0L; + + private Duration valuesTimeout = DEFAULT_VALUES_TIMEOUT; + + private boolean valuesStorage = true; + + // + // ============================================================================================================== + // Static methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throws an {@link AssertionError} with the specified error message + * supplier. + * + * @param timeout the timeout duration + * @param errorMessageSupplier the error message supplier + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, Supplier errorMessageSupplier, BooleanSupplier conditionSupplier) { + + Objects.requireNonNull(errorMessageSupplier); + Objects.requireNonNull(conditionSupplier); + Objects.requireNonNull(timeout); + + long timeoutNs = timeout.toNanos(); + long startTime = System.nanoTime(); + do { + if (conditionSupplier.getAsBoolean()) { + return; + } + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } while (System.nanoTime() - startTime < timeoutNs); + throw new AssertionError(errorMessageSupplier.get()); + } + + /** + * Blocking method that waits until {@code conditionSupplier} returns true, or if it does not + * before the specified timeout, throw an {@link AssertionError} with the specified error message. + * + * @param timeout the timeout duration + * @param errorMessage the error message + * @param conditionSupplier condition to break out of the wait loop + * @throws AssertionError + */ + public static void await( + Duration timeout, final String errorMessage, BooleanSupplier conditionSupplier) { + await( + timeout, + new Supplier() { + @Override + public String get() { + return errorMessage; + } + }, + conditionSupplier); + } + + /** + * Create a new {@link AssertSubscriber} that requests an unbounded number of elements. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create() { + return new AssertSubscriber<>(); + } + + /** + * Create a new {@link AssertSubscriber} that requests initially {@code n} elements. You can then + * manage the demand with {@link Subscription#request(long)}. + * + *

Be sure at least a publisher has subscribed to it via {@link + * Publisher#subscribe(Subscriber)} before use assert methods. + * + * @param n Number of elements to request (can be 0 if you want no initial demand). + * @param the observed value type + * @return a fresh AssertSubscriber instance + */ + public static AssertSubscriber create(long n) { + return new AssertSubscriber<>(n); + } + + // + // ============================================================================================================== + // constructors + // + // ============================================================================================================== + + public AssertSubscriber() { + this(Context.empty(), Long.MAX_VALUE); + } + + public AssertSubscriber(long n) { + this(Context.empty(), n); + } + + public AssertSubscriber(Context context) { + this(context, Long.MAX_VALUE); + } + + public AssertSubscriber(Context context, long n) { + if (n < 0) { + throw new IllegalArgumentException("initialRequest >= required but it was " + n); + } + this.context = context; + REQUESTED.lazySet(this, n); + } + + // + // ============================================================================================================== + // Configuration + // + // ============================================================================================================== + + /** + * Enable or disabled the values storage. It is enabled by default, and can be disable in order to + * be able to perform performance benchmarks or tests with a huge amount values. + * + * @param enabled enable value storage? + * @return this + */ + public final AssertSubscriber configureValuesStorage(boolean enabled) { + this.valuesStorage = enabled; + return this; + } + + /** + * Configure the timeout in seconds for waiting next values to be received (3 seconds by default). + * + * @param timeout the new default value timeout duration + * @return this + */ + public final AssertSubscriber configureValuesTimeout(Duration timeout) { + this.valuesTimeout = timeout; + return this; + } + + /** + * Returns the established fusion mode or -1 if it was not enabled + * + * @return the fusion mode, see Fuseable constants + */ + public final int establishedFusionMode() { + return establishedFusionMode; + } + + // + // ============================================================================================================== + // Assertions + // + // ============================================================================================================== + + /** + * Assert a complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertComplete() { + assertNoError(); + int c = completionCount; + if (c == 0) { + throw new AssertionError("Not completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert the specified values have been received. Values storage should be enabled to use this + * method. + * + * @param expectedValues the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertContainValues(Set expectedValues) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + if (expectedValues.size() > values.size()) { + throw new AssertionError("Actual contains fewer elements" + values, null); + } + + Iterator expected = expectedValues.iterator(); + + for (; ; ) { + boolean n2 = expected.hasNext(); + if (n2) { + T t2 = expected.next(); + if (!values.contains(t2)) { + throw new AssertionError( + "The element is not contained in the " + + "received results" + + " = " + + valueAndClass(t2), + null); + } + } else { + break; + } + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertError() { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert an error signal has been received. + * + * @param clazz The class of the exception contained in the error signal + * @return this + */ + public final AssertSubscriber assertError(Class clazz) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + Throwable e = errors.get(0); + if (!clazz.isInstance(e)) { + throw new AssertionError( + "Error class incompatible: expected = " + clazz + ", actual = " + e, null); + } + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + errors, null); + } + return this; + } + + public final AssertSubscriber assertErrorMessage(String message) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + assertionError("No error", null); + } + if (s == 1) { + if (!Objects.equals(message, errors.get(0).getMessage())) { + assertionError( + "Error class incompatible: expected = \"" + + message + + "\", actual = \"" + + errors.get(0).getMessage() + + "\"", + null); + } + } + if (s > 1) { + assertionError("Multiple errors: " + s, null); + } + + return this; + } + + /** + * Assert an error signal has been received. + * + * @param expectation A method that can verify the exception contained in the error signal and + * throw an exception (like an {@link AssertionError}) if the exception is not valid. + * @return this + */ + public final AssertSubscriber assertErrorWith(Consumer expectation) { + assertNotComplete(); + int s = errors.size(); + if (s == 0) { + throw new AssertionError("No error", null); + } + if (s == 1) { + expectation.accept(errors.get(0)); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert that the upstream was a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertFuseableSource() { + if (qs == null) { + throw new AssertionError("Upstream was not Fuseable"); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionEnabled() { + if (establishedFusionMode != Fuseable.SYNC && establishedFusionMode != Fuseable.ASYNC) { + throw new AssertionError("Fusion was not enabled"); + } + return this; + } + + public final AssertSubscriber assertFusionMode(int expectedMode) { + if (establishedFusionMode != expectedMode) { + throw new AssertionError( + "Wrong fusion mode: expected: " + + fusionModeName(expectedMode) + + ", actual: " + + fusionModeName(establishedFusionMode)); + } + return this; + } + + /** + * Assert that the fusion mode was granted. + * + * @return this + */ + public final AssertSubscriber assertFusionRejected() { + if (establishedFusionMode != Fuseable.NONE) { + throw new AssertionError("Fusion was granted"); + } + return this; + } + + /** + * Assert no error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNoError() { + int s = errors.size(); + if (s == 1) { + Throwable e = errors.get(0); + String valueAndClass = e == null ? null : e + " (" + e.getClass().getSimpleName() + ")"; + throw new AssertionError("Error present: " + valueAndClass, null); + } + if (s > 1) { + throw new AssertionError("Multiple errors: " + s, null); + } + return this; + } + + /** + * Assert no values have been received. + * + * @return this + */ + public final AssertSubscriber assertNoValues() { + if (valueCount != 0) { + throw new AssertionError( + "No values expected but received: [length = " + values.size() + "] " + values, null); + } + return this; + } + + /** + * Assert that the upstream was not a Fuseable source. + * + * @return this + */ + public final AssertSubscriber assertNonFuseableSource() { + if (qs != null) { + throw new AssertionError("Upstream was Fuseable"); + } + return this; + } + + /** + * Assert no complete successfully signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotComplete() { + int c = completionCount; + if (c == 1) { + throw new AssertionError("Completed", null); + } + if (c > 1) { + throw new AssertionError("Multiple completions: " + c, null); + } + return this; + } + + /** + * Assert no subscription occurred. + * + * @return this + */ + public final AssertSubscriber assertNotSubscribed() { + int s = subscriptionCount; + + if (s == 1) { + throw new AssertionError("OnSubscribe called once", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert no complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertNotTerminated() { + if (cdl.getCount() == 0) { + throw new AssertionError("Terminated", null); + } + return this; + } + + /** + * Assert subscription occurred (once). + * + * @return this + */ + public final AssertSubscriber assertSubscribed() { + int s = subscriptionCount; + + if (s == 0) { + throw new AssertionError("OnSubscribe not called", null); + } + if (s > 1) { + throw new AssertionError("OnSubscribe called multiple times: " + s, null); + } + + return this; + } + + /** + * Assert either complete successfully or error signal has been received. + * + * @return this + */ + public final AssertSubscriber assertTerminated() { + if (cdl.getCount() != 0) { + throw new AssertionError("Not terminated", null); + } + return this; + } + + /** + * Assert {@code n} values has been received. + * + * @param n the expected value count + * @return this + */ + public final AssertSubscriber assertValueCount(long n) { + if (valueCount != n) { + throw new AssertionError( + "Different value count: expected = " + n + ", actual = " + valueCount, null); + } + return this; + } + + /** + * Assert the specified values have been received in the same order read by the passed {@link + * Iterable}. Values storage should be enabled to use this method. + * + * @param expectedSequence the values to assert + * @see #configureValuesStorage(boolean) + * @return this + */ + public final AssertSubscriber assertValueSequence(Iterable expectedSequence) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + Iterator actual = values.iterator(); + Iterator expected = expectedSequence.iterator(); + int i = 0; + for (; ; ) { + boolean n1 = actual.hasNext(); + boolean n2 = expected.hasNext(); + if (n1 && n2) { + T t1 = actual.next(); + T t2 = expected.next(); + if (!Objects.equals(t1, t2)) { + throw new AssertionError( + "The element with index " + + i + + " does not match: expected = " + + valueAndClass(t2) + + ", actual = " + + valueAndClass(t1), + null); + } + i++; + } else if (n1 && !n2) { + throw new AssertionError("Actual contains more elements" + values, null); + } else if (!n1 && n2) { + throw new AssertionError("Actual contains fewer elements: " + values, null); + } else { + break; + } + } + return this; + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectedValues the values to assert + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValues(T... expectedValues) { + return assertValueSequence(Arrays.asList(expectedValues)); + } + + /** + * Assert the specified values have been received in the declared order. Values storage should be + * enabled to use this method. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + * @see #configureValuesStorage(boolean) + */ + @SafeVarargs + public final AssertSubscriber assertValuesWith(Consumer... expectations) { + if (!valuesStorage) { + throw new IllegalStateException("Using assertNoValues() requires enabling values storage"); + } + final int expectedValueCount = expectations.length; + if (expectedValueCount != values.size()) { + throw new AssertionError( + "Different value count: expected = " + expectedValueCount + ", actual = " + valueCount, + null); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = values.get(i); + consumer.accept(actualValue); + } + return this; + } + + // + // ============================================================================================================== + // Await methods + // + // ============================================================================================================== + + /** + * Blocking method that waits until a complete successfully or error signal is received. + * + * @return this + */ + public final AssertSubscriber await() { + if (cdl.getCount() == 0) { + return this; + } + try { + cdl.await(); + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + return this; + } + + /** + * Blocking method that waits until a complete successfully or error signal is received or until a + * timeout occurs. + * + * @param timeout The timeout value + * @return this + */ + public final AssertSubscriber await(Duration timeout) { + if (cdl.getCount() == 0) { + return this; + } + try { + if (!cdl.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { + throw new AssertionError("No complete or error signal before timeout"); + } + return this; + } catch (InterruptedException ex) { + throw new AssertionError("Wait interrupted", ex); + } + } + + /** + * Blocking method that waits until {@code n} next values have been received. + * + * @param n the value count to assert + * @return this + */ + public final AssertSubscriber awaitAndAssertNextValueCount(final long n) { + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + n, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d", + valueCount - nextValueAssertedCount, n, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + n)); + nextValueAssertedCount += n; + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * values provided) to assert them. + * + * @param values the values to assert + * @return this + */ + @SafeVarargs + @SuppressWarnings("unchecked") + public final AssertSubscriber awaitAndAssertNextValues(T... values) { + final int expectedNum = values.length; + final List> expectations = new ArrayList<>(); + for (int i = 0; i < expectedNum; i++) { + final T expectedValue = values[i]; + expectations.add( + actualValue -> { + if (!actualValue.equals(expectedValue)) { + throw new AssertionError( + String.format( + "Expected Next signal: %s, but got: %s", expectedValue, actualValue)); + } + }); + } + awaitAndAssertNextValuesWith(expectations.toArray((Consumer[]) new Consumer[0])); + return this; + } + + /** + * Blocking method that waits until {@code n} next values have been received (n is the number of + * expectations provided) to assert them. + * + * @param expectations One or more methods that can verify the values and throw a exception (like + * an {@link AssertionError}) if the value is not valid. + * @return this + */ + @SafeVarargs + public final AssertSubscriber awaitAndAssertNextValuesWith(Consumer... expectations) { + valuesStorage = true; + final int expectedValueCount = expectations.length; + await( + valuesTimeout, + () -> { + if (valuesStorage) { + return String.format( + "%d out of %d next values received within %d, " + "values : %s", + valueCount - nextValueAssertedCount, + expectedValueCount, + valuesTimeout.toMillis(), + values.toString()); + } + return String.format( + "%d out of %d next values received within %d ms", + valueCount - nextValueAssertedCount, expectedValueCount, valuesTimeout.toMillis()); + }, + () -> valueCount >= (nextValueAssertedCount + expectedValueCount)); + List nextValuesSnapshot; + List empty = new ArrayList<>(); + for (; ; ) { + nextValuesSnapshot = values; + if (NEXT_VALUES.compareAndSet(this, values, empty)) { + break; + } + } + if (nextValuesSnapshot.size() < expectedValueCount) { + throw new AssertionError( + String.format( + "Expected %d number of signals but received %d", + expectedValueCount, nextValuesSnapshot.size())); + } + for (int i = 0; i < expectedValueCount; i++) { + Consumer consumer = expectations[i]; + T actualValue = nextValuesSnapshot.get(i); + consumer.accept(actualValue); + } + nextValueAssertedCount += expectedValueCount; + return this; + } + + // + // ============================================================================================================== + // Overrides + // + // ============================================================================================================== + + @Override + public void cancel() { + Subscription a = s; + if (a != Operators.cancelledSubscription()) { + a = S.getAndSet(this, Operators.cancelledSubscription()); + if (a != null && a != Operators.cancelledSubscription()) { + a.cancel(); + + if (establishedFusionMode == Fuseable.ASYNC) { + final int previousState = markWorkAdded(); + if (!isWorkInProgress(previousState)) { + clearAndFinalize(); + } + } + } + } + } + + final boolean isCancelled() { + return s == Operators.cancelledSubscription(); + } + + public final boolean isTerminated() { + return cdl.getCount() == 0; + } + + @Override + public void onComplete() { + done = true; + completionCount++; + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + cdl.countDown(); + } + + @Override + public void onError(Throwable t) { + done = true; + errors.add(t); + + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + return; + } + + cdl.countDown(); + } + + @Override + public void onNext(T t) { + if (establishedFusionMode == Fuseable.ASYNC) { + drain(); + } else { + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } + + static boolean isFinalized(int state) { + return state == Integer.MIN_VALUE; + } + + static boolean isWorkInProgress(int state) { + return state > 0; + } + + int markWorkAdded() { + for (; ; ) { + int state = this.wip; + + if (isFinalized(state)) { + return state; + } + + if ((state & Integer.MAX_VALUE) == Integer.MAX_VALUE) { + return state; + } + int nextState = state + 1; + + if (WIP.compareAndSet(this, state, nextState)) { + return state; + } + } + } + + void clearAndFinalize() { + final Fuseable.QueueSubscription qs = this.qs; + for (; ; ) { + int state = this.wip; + + qs.clear(); + + if (WIP.compareAndSet(this, state, Integer.MIN_VALUE)) { + return; + } + } + } + + void drain() { + final int previousState = markWorkAdded(); + if (isWorkInProgress(previousState)) { + return; + } + + if (isFinalized(previousState)) { + qs.clear(); + return; + } + + T t; + int m = 1; + for (; ; ) { + if (isCancelled()) { + clearAndFinalize(); + break; + } + boolean done = this.done; + t = qs.poll(); + if (t == null) { + if (done) { + clearAndFinalize(); + cdl.countDown(); + return; + } + m = WIP.addAndGet(this, -m); + if (m == 0) { + break; + } + continue; + } + valueCount++; + if (valuesStorage) { + List nextValuesSnapshot; + for (; ; ) { + nextValuesSnapshot = values; + nextValuesSnapshot.add(t); + if (NEXT_VALUES.compareAndSet(this, nextValuesSnapshot, nextValuesSnapshot)) { + break; + } + } + } + } + } + + @Override + @SuppressWarnings("unchecked") + public void onSubscribe(Subscription s) { + subscriptionCount++; + int requestMode = requestedFusionMode; + if (requestMode >= 0) { + if (s instanceof Fuseable.QueueSubscription) { + this.qs = (Fuseable.QueueSubscription) s; + + int m = qs.requestFusion(requestMode); + establishedFusionMode = m; + + if (!setWithoutRequesting(s)) { + qs.clear(); + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + return; + } + + if (m == Fuseable.SYNC) { + for (; ; ) { + T v = qs.poll(); + if (v == null) { + onComplete(); + break; + } + + onNext(v); + } + } else { + requestDeferred(); + } + + return; + } + } + + if (!set(s)) { + if (!isCancelled()) { + errors.add(new IllegalStateException("Subscription already set: " + subscriptionCount)); + } + } + } + + @Override + public void request(long n) { + if (Operators.validate(n)) { + if (establishedFusionMode != Fuseable.SYNC) { + normalRequest(n); + } + } + } + + @Override + @NonNull + public Context currentContext() { + return context; + } + + /** + * Setup what fusion mode should be requested from the incoming Subscription if it happens to be + * QueueSubscription + * + * @param requestMode the mode to request, see Fuseable constants + * @return this + */ + public final AssertSubscriber requestedFusionMode(int requestMode) { + this.requestedFusionMode = requestMode; + return this; + } + + public Subscription upstream() { + return s; + } + + // + // ============================================================================================================== + // Non public methods + // + // ============================================================================================================== + + protected final void normalRequest(long n) { + Subscription a = s; + if (a != null) { + a.request(n); + } else { + Operators.addCap(REQUESTED, this, n); + + a = s; + + if (a != null) { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + a.request(r); + } + } + } + } + + /** Requests the deferred amount if not zero. */ + protected final void requestDeferred() { + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + } + + /** + * Atomically sets the single subscription and requests the missed amount from it. + * + * @param s + * @return false if this arbiter is cancelled or there was a subscription already set + */ + protected final boolean set(Subscription s) { + Objects.requireNonNull(s, "s"); + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + + long r = REQUESTED.getAndSet(this, 0L); + + if (r != 0L) { + s.request(r); + } + + return true; + } + + a = this.s; + + if (a != Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + + Operators.reportSubscriptionSet(); + return false; + } + + /** + * Sets the Subscription once but does not request anything. + * + * @param s the Subscription to set + * @return true if successful, false if the current subscription is not null + */ + protected final boolean setWithoutRequesting(Subscription s) { + Objects.requireNonNull(s, "s"); + for (; ; ) { + Subscription a = this.s; + if (a == Operators.cancelledSubscription()) { + s.cancel(); + return false; + } + if (a != null) { + s.cancel(); + Operators.reportSubscriptionSet(); + return false; + } + + if (S.compareAndSet(this, null, s)) { + return true; + } + } + } + + /** + * Prepares and throws an AssertionError exception based on the message, cause, the active state + * and the potential errors so far. + * + * @param message the message + * @param cause the optional Throwable cause + * @throws AssertionError as expected + */ + protected final void assertionError(String message, Throwable cause) { + StringBuilder b = new StringBuilder(); + + if (cdl.getCount() != 0) { + b.append("(active) "); + } + b.append(message); + + List err = errors; + if (!err.isEmpty()) { + b.append(" (+ ").append(err.size()).append(" errors)"); + } + AssertionError e = new AssertionError(b.toString(), cause); + + for (Throwable t : err) { + e.addSuppressed(t); + } + + throw e; + } + + protected final String fusionModeName(int mode) { + switch (mode) { + case -1: + return "Disabled"; + case Fuseable.NONE: + return "None"; + case Fuseable.SYNC: + return "Sync"; + case Fuseable.ASYNC: + return "Async"; + default: + return "Unknown(" + mode + ")"; + } + } + + protected final String valueAndClass(Object o) { + if (o == null) { + return null; + } + return o + " (" + o.getClass().getSimpleName() + ")"; + } + + public List values() { + return values; + } + + public List errors() { + return errors; + } + + public final AssertSubscriber assertNoEvents() { + return assertNoValues().assertNoError().assertNotComplete(); + } + + @SafeVarargs + public final AssertSubscriber assertIncomplete(T... values) { + return assertValues(values).assertNotComplete().assertNoError(); + } + + @Override + public Object scanUnsafe(Attr key) { + if (key == Attr.PARENT) { + return upstream(); + } + + boolean t = isTerminated(); + if (key == Attr.TERMINATED) return t; + if (key == Attr.ERROR) return (!errors.isEmpty() ? errors.get(0) : null); + if (key == Attr.PREFETCH) return Integer.MAX_VALUE; + if (key == Attr.CANCELLED) return isCancelled(); + if (key == Attr.RUN_STYLE) return Attr.RunStyle.SYNC; + + return null; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java new file mode 100644 index 000000000..9ebca34f7 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/lease/LeaseImplTest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.lease; + +public class LeaseImplTest { + // + // @Test + // public void emptyLeaseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // Assertions.assertTrue(empty.isEmpty()); + // Assertions.assertFalse(empty.isValid()); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void emptyLeaseUseNoAvailability() { + // LeaseImpl empty = LeaseImpl.empty(); + // boolean success = empty.use(); + // assertFalse(success); + // Assertions.assertEquals(0.0, empty.availability(), 1e-5); + // } + // + // @Test + // public void leaseAvailability() { + // LeaseImpl lease = LeaseImpl.create(2, 100, Unpooled.EMPTY_BUFFER); + // Assertions.assertEquals(1.0, lease.availability(), 1e-5); + // } + // + // @Test + // public void leaseUseDecreasesAvailability() { + // LeaseImpl lease = LeaseImpl.create(30_000, 2, Unpooled.EMPTY_BUFFER); + // boolean success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.5, lease.availability(), 1e-5); + // Assertions.assertTrue(lease.isValid()); + // success = lease.use(); + // Assertions.assertTrue(success); + // Assertions.assertEquals(0.0, lease.availability(), 1e-5); + // Assertions.assertFalse(lease.isValid()); + // Assertions.assertEquals(0, lease.getAllowedRequests()); + // success = lease.use(); + // Assertions.assertFalse(success); + // } + // + // @Test + // public void leaseTimeout() { + // int numberOfRequests = 1; + // LeaseImpl lease = LeaseImpl.create(1, numberOfRequests, Unpooled.EMPTY_BUFFER); + // Mono.delay(Duration.ofMillis(100)).block(); + // boolean success = lease.use(); + // Assertions.assertFalse(success); + // Assertions.assertTrue(lease.isExpired()); + // Assertions.assertEquals(numberOfRequests, lease.getAllowedRequests()); + // Assertions.assertFalse(lease.isValid()); + // } + // + // @Test + // public void useLeaseChangesAllowedRequests() { + // int numberOfRequests = 2; + // LeaseImpl lease = LeaseImpl.create(30_000, numberOfRequests, Unpooled.EMPTY_BUFFER); + // lease.use(); + // assertEquals(numberOfRequests - 1, lease.getAllowedRequests()); + // } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java new file mode 100644 index 000000000..a35e89391 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceRSocketClientTest.java @@ -0,0 +1,94 @@ +package io.rsocket.loadbalance; + +import static java.util.Collections.singletonList; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +@ExtendWith(MockitoExtension.class) +class LoadbalanceRSocketClientTest { + + @Mock private ClientTransport clientTransport; + @Mock private RSocketConnector rSocketConnector; + + public static final Duration SHORT_DURATION = Duration.ofMillis(25); + public static final Duration LONG_DURATION = Duration.ofMillis(75); + + private static final Publisher SOURCE = + Flux.interval(SHORT_DURATION) + .onBackpressureBuffer() + .map(String::valueOf) + .map(DefaultPayload::create); + + private static final Mono PROGRESSING_HANDLER = + Mono.just( + new RSocket() { + private final AtomicInteger i = new AtomicInteger(); + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .delayElements(SHORT_DURATION) + .map(Payload::getDataUtf8) + .map(DefaultPayload::create) + .take(i.incrementAndGet()); + } + }); + + @Test + void testChannelReconnection() { + when(rSocketConnector.connect(clientTransport)).thenReturn(PROGRESSING_HANDLER); + + RSocketClient client = + LoadbalanceRSocketClient.create( + rSocketConnector, + Mono.just(singletonList(LoadbalanceTarget.from("key", clientTransport)))); + + Publisher result = + client + .requestChannel(SOURCE) + .repeatWhen(longFlux -> longFlux.delayElements(LONG_DURATION).take(5)) + .map(Payload::getDataUtf8) + .log(); + + StepVerifier.create(result) + .expectSubscription() + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("0")) + .assertNext(s -> assertThat(s).isEqualTo("1")) + .assertNext(s -> assertThat(s).isEqualTo("2")) + .assertNext(s -> assertThat(s).isEqualTo("3")) + .assertNext(s -> assertThat(s).isEqualTo("4")) + .verifyComplete(); + + verify(rSocketConnector).connect(clientTransport); + verifyNoMoreInteractions(rSocketConnector, clientTransport); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java new file mode 100644 index 000000000..c1b509297 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/LoadbalanceTest.java @@ -0,0 +1,470 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.internal.subscriber.AssertSubscriber; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import io.rsocket.util.RSocketProxy; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; +import reactor.test.publisher.TestPublisher; +import reactor.test.util.RaceTestUtils; +import reactor.util.context.Context; + +public class LoadbalanceTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverAllTheRequestsWithRoundRobinStrategy() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = new TestClientTransport(); + final RSocket rSocket = + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }; + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(new TestRSocket(rSocket))); + + final List collectionOfDestination1 = + Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)); + final List collectionOfDestination2 = + Collections.singletonList(LoadbalanceTarget.from("2", mockTransport)); + final List collectionOfDestinations1And2 = + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), LoadbalanceTarget.from("2", mockTransport)); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, source.asFlux(), new RoundRobinLoadbalanceStrategy()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void shouldDeliverAllTheRequestsWithWeightedStrategy() throws InterruptedException { + final AtomicInteger counter = new AtomicInteger(); + + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + + final LoadbalanceTarget target1 = LoadbalanceTarget.from("1", mockTransport1); + final LoadbalanceTarget target2 = LoadbalanceTarget.from("2", mockTransport2); + + final WeightedRSocket weightedRSocket1 = new WeightedRSocket(counter); + final WeightedRSocket weightedRSocket2 = new WeightedRSocket(counter); + + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + Mockito.when(rSocketConnectorMock.connect(mockTransport1)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket1))); + Mockito.when(rSocketConnectorMock.connect(mockTransport2)) + .then(im -> Mono.just(new TestRSocket(weightedRSocket2))); + final List collectionOfDestination1 = Collections.singletonList(target1); + final List collectionOfDestination2 = Collections.singletonList(target2); + final List collectionOfDestinations1And2 = Arrays.asList(target1, target2); + + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final Sinks.Many> source = + Sinks.unsafe().many().unicast().onBackpressureError(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source.asFlux(), + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver( + rsocket -> { + if (rsocket instanceof TestRSocket) { + return (WeightedRSocket) ((TestRSocket) rsocket).source(); + } + return ((PooledRSocket) rsocket).target() == target1 + ? weightedRSocket1 + : weightedRSocket2; + }) + .build()); + final Mono fnfSource = + Mono.defer(() -> rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)); + + RaceTestUtils.race( + () -> { + for (int j = 0; j < 1000; j++) { + fnfSource.subscribe(new RetrySubscriber(fnfSource)); + } + }, + () -> { + for (int j = 0; j < 100; j++) { + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination1, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(Collections.emptyList(), Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestination2, Sinks.EmitFailureHandler.FAIL_FAST); + source.emitNext(collectionOfDestinations1And2, Sinks.EmitFailureHandler.FAIL_FAST); + } + }); + + Assertions.assertThat(counter.get()).isEqualTo(1000); + counter.set(0); + } + } + + @Test + public void ensureRSocketIsCleanedFromThePoolIfSourceRSocketIsDisposed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + final TestRSocket testRSocket = + new TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.delay(Duration.ofMillis(200)).map(__ -> testRSocket)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + testRSocket.dispose(); + + Assertions.assertThatThrownBy( + () -> + rSocketPool + .select() + .fireAndForget(EmptyPayload.INSTANCE) + .block(Duration.ofSeconds(2))) + .isExactlyInstanceOf(IllegalStateException.class) + .hasMessage("Timeout on blocking read for 2000000000 NANOSECONDS"); + + Assertions.assertThat(counter.get()).isOne(); + } + + @Test + public void ensureContextIsPropagatedCorrectlyForRequestChannel() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.delay(Duration.ofMillis(200)) + .map( + __ -> + new TestRSocket( + new RSocket() { + @Override + public Flux requestChannel(Publisher source) { + counter.incrementAndGet(); + return Flux.from(source); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + // check that context is propagated when there is no rsocket + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .then( + () -> + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport)))) + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport))); + // check that context is propagated when there is an RSocket but it is unresolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + // check that context is propagated when there is an RSocket and it is resolved + StepVerifier.create( + rSocketPool + .select() + .requestChannel( + Flux.deferContextual( + cv -> { + if (cv.hasKey("test") && cv.get("test").equals("test")) { + return Flux.just(EmptyPayload.INSTANCE); + } else { + return Flux.error( + new IllegalStateException("Expected context to be propagated")); + } + })) + .contextWrite(Context.of("test", "test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(Duration.ofSeconds(2)); + + Assertions.assertThat(counter.get()).isEqualTo(3); + } + + @Test + public void shouldNotifyOnCloseWhenAllTheActiveSubscribersAreClosed() { + final AtomicInteger counter = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Sinks.Empty onCloseSocket1 = Sinks.empty(); + Sinks.Empty onCloseSocket2 = Sinks.empty(); + + RSocket socket1 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket1.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + RSocket socket2 = + new RSocket() { + @Override + public Mono onClose() { + return onCloseSocket2.asMono(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + }; + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(socket1)) + .then(im -> Mono.just(socket2)) + .then(im -> Mono.never().doOnCancel(() -> counter.incrementAndGet())); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport), + LoadbalanceTarget.from("3", mockTransport))); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + StepVerifier.create(rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectComplete() + .verify(Duration.ofSeconds(2)); + + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + + rSocketPool.dispose(); + + AssertSubscriber onCloseSubscriber = + rSocketPool.onClose().subscribeWith(AssertSubscriber.create()); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket1.tryEmitEmpty(); + + onCloseSubscriber.assertNotTerminated(); + + onCloseSocket2.tryEmitEmpty(); + + onCloseSubscriber.assertTerminated().assertComplete(); + + Assertions.assertThat(counter.get()).isOne(); + } + + static class TestRSocket extends RSocketProxy { + + final Sinks.Empty sink = Sinks.empty(); + + public TestRSocket(RSocket rSocket) { + super(rSocket); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return source; + } + } + + private static class WeightedRSocket extends BaseWeightedStats implements RSocket { + + private final AtomicInteger counter; + + public WeightedRSocket(AtomicInteger counter) { + this.counter = counter; + } + + @Override + public Mono fireAndForget(Payload payload) { + final long startTime = startRequest(); + counter.incrementAndGet(); + return Mono.empty() + .doFinally( + (__) -> { + final long stopTime = stopRequest(startTime); + record(stopTime - startTime); + }); + } + } + + static class RetrySubscriber implements CoreSubscriber { + + final Publisher source; + + private RetrySubscriber(Publisher source) { + this.source = source; + } + + @Override + public void onSubscribe(Subscription s) { + s.request(Long.MAX_VALUE); + } + + @Override + public void onNext(Void unused) {} + + @Override + public void onError(Throwable t) { + source.subscribe(this); + } + + @Override + public void onComplete() {} + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java new file mode 100644 index 000000000..e43068dbd --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/RoundRobinLoadbalanceStrategyTest.java @@ -0,0 +1,170 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.test.publisher.TestPublisher; + +public class RoundRobinLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void shouldDeliverValuesProportionally() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()).isCloseTo(500, Offset.offset(1)); + Assertions.assertThat(counter2.get()).isCloseTo(500, Offset.offset(1)); + } + + @Test + public void shouldDeliverValuesToNewlyConnectedSockets() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then( + im -> + Mono.just( + new LoadbalanceTest.TestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + if (im.getArgument(0) == mockTransport1) { + counter1.incrementAndGet(); + } else { + counter2.incrementAndGet(); + } + return Mono.empty(); + } + }))); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool(rSocketConnectorMock, source, new RoundRobinLoadbalanceStrategy()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2 + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS + RaceTestConstants.REPEATS / 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 3, Offset.offset(1)); + Assertions.assertThat(counter2.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java new file mode 100644 index 000000000..8cc254cbb --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/loadbalance/WeightedLoadbalanceStrategyTest.java @@ -0,0 +1,254 @@ +package io.rsocket.loadbalance; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RaceTestConstants; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.ClientTransport; +import io.rsocket.util.Clock; +import io.rsocket.util.EmptyPayload; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.assertj.core.data.Offset; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.publisher.TestPublisher; + +public class WeightedLoadbalanceStrategyTest { + + @BeforeEach + void setUp() { + Hooks.onErrorDropped((__) -> {}); + } + + @AfterAll + static void afterAll() { + Hooks.resetOnErrorDropped(); + } + + @Test + public void allRequestsShouldGoToTheSocketWithHigherWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport), + LoadbalanceTarget.from("2", mockTransport))); + + Assertions.assertThat(counter1.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + Assertions.assertThat(counter2.get()) + .describedAs("c1=" + counter1.get() + " c2=" + counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 0.1f))); + } + + @Test + public void shouldDeliverValuesToTheSocketWithTheHighestCalculatedWeight() { + final AtomicInteger counter1 = new AtomicInteger(); + final AtomicInteger counter2 = new AtomicInteger(); + final ClientTransport mockTransport1 = Mockito.mock(ClientTransport.class); + final ClientTransport mockTransport2 = Mockito.mock(ClientTransport.class); + final RSocketConnector rSocketConnectorMock = Mockito.mock(RSocketConnector.class); + final WeightedTestRSocket rSocket1 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket2 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter1.incrementAndGet(); + return Mono.empty(); + } + }); + final WeightedTestRSocket rSocket3 = + new WeightedTestRSocket( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + counter2.incrementAndGet(); + return Mono.empty(); + } + }); + + Mockito.when(rSocketConnectorMock.connect(Mockito.any(ClientTransport.class))) + .then(im -> Mono.just(rSocket1)) + .then(im -> Mono.just(rSocket2)) + .then(im -> Mono.just(rSocket3)); + + final TestPublisher> source = TestPublisher.create(); + final RSocketPool rSocketPool = + new RSocketPool( + rSocketConnectorMock, + source, + WeightedLoadbalanceStrategy.builder() + .weightedStatsResolver(r -> r instanceof WeightedStats ? (WeightedStats) r : null) + .build()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()).isCloseTo(RaceTestConstants.REPEATS, Offset.offset(1)); + + source.next(Collections.emptyList()); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + rSocket1.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("1", mockTransport1))); + + Assertions.assertThat(counter1.get()) + .isCloseTo(RaceTestConstants.REPEATS * 2, Offset.offset(1)); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo(0, Offset.offset(Math.round(RaceTestConstants.REPEATS * 3 * 0.1f))); + + rSocket2.updateAvailability(0.0); + + source.next(Collections.singletonList(LoadbalanceTarget.from("2", mockTransport1))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + rSocketPool.select().fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 4 * 0.1f))); + + source.next( + Arrays.asList( + LoadbalanceTarget.from("1", mockTransport1), + LoadbalanceTarget.from("2", mockTransport2))); + + for (int j = 0; j < RaceTestConstants.REPEATS; j++) { + final RSocket rSocket = rSocketPool.select(); + rSocket.fireAndForget(EmptyPayload.INSTANCE).subscribe(); + } + + Assertions.assertThat(counter1.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 3, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + Assertions.assertThat(counter2.get()) + .isCloseTo( + RaceTestConstants.REPEATS * 2, + Offset.offset(Math.round(RaceTestConstants.REPEATS * 5 * 0.1f))); + } + + static class WeightedTestRSocket extends BaseWeightedStats implements RSocket { + + final Sinks.Empty sink = Sinks.empty(); + + final RSocket rSocket; + + public WeightedTestRSocket(RSocket rSocket) { + this.rSocket = rSocket; + } + + @Override + public Mono fireAndForget(Payload payload) { + startRequest(); + final long startTime = Clock.now(); + return this.rSocket + .fireAndForget(payload) + .doFinally( + __ -> { + stopRequest(startTime); + record(Clock.now() - startTime); + updateAvailability(1.0); + }); + } + + @Override + public Mono onClose() { + return sink.asMono(); + } + + @Override + public void dispose() { + sink.tryEmitEmpty(); + } + + public RSocket source() { + return rSocket; + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java new file mode 100644 index 000000000..58ab30021 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/AuthMetadataCodecTest.java @@ -0,0 +1,474 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class AuthMetadataCodecTest { + + public static final int AUTH_TYPE_ID_LENGTH = 1; + public static final int USER_NAME_BYTES_LENGTH = 2; + public static final String TEST_BEARER_TOKEN = + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyLCJpYXQxIjoxNTE2MjM5MDIyLCJpYXQyIjoxNTE2MjM5MDIyLCJpYXQzIjoxNTE2MjM5MDIyLCJpYXQ0IjoxNTE2MjM5MDIyfQ.ljYuH-GNyyhhLcx-rHMchRkGbNsR2_4aSxo8XjrYrSM"; + + @Test + void shouldCorrectlyEncodeData() { + String username = "test"; + String password = "tset1234"; + + int usernameLength = username.length(); + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData1() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + @Test + void shouldCorrectlyEncodeData2() { + String username = "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎1234567#4? "; + String password = "tset1234"; + + int usernameLength = username.getBytes(CharsetUtil.UTF_8).length; + int passwordLength = password.length(); + + ByteBuf byteBuf = + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, username.toCharArray(), password.toCharArray()); + + byteBuf.markReaderIndex(); + checkSimpleAuthMetadataEncoding( + username, password, usernameLength, passwordLength, byteBuf.retain()); + byteBuf.resetReaderIndex(); + checkSimpleAuthMetadataEncodingUsingDecoders( + username, password, usernameLength, passwordLength, byteBuf); + } + + private static void checkSimpleAuthMetadataEncoding( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(byteBuf.readUnsignedByte() & ~0x80) + .isEqualTo(WellKnownAuthType.SIMPLE.getIdentifier()); + Assertions.assertThat(byteBuf.readUnsignedShort()).isEqualTo((short) usernameLength); + + Assertions.assertThat(byteBuf.readCharSequence(usernameLength, CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(byteBuf.readCharSequence(passwordLength, CharsetUtil.UTF_8)) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + private static void checkSimpleAuthMetadataEncodingUsingDecoders( + String username, String password, int usernameLength, int passwordLength, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(AUTH_TYPE_ID_LENGTH + USER_NAME_BYTES_LENGTH + usernameLength + passwordLength); + + Assertions.assertThat(AuthMetadataCodec.readWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.SIMPLE); + byteBuf.markReaderIndex(); + Assertions.assertThat(AuthMetadataCodec.readUsername(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(username); + Assertions.assertThat(AuthMetadataCodec.readPassword(byteBuf).toString(CharsetUtil.UTF_8)) + .isEqualTo(password); + byteBuf.resetReaderIndex(); + + Assertions.assertThat(new String(AuthMetadataCodec.readUsernameAsCharArray(byteBuf))) + .isEqualTo(username); + Assertions.assertThat(new String(AuthMetadataCodec.readPasswordAsCharArray(byteBuf))) + .isEqualTo(password); + + ReferenceCountUtil.release(byteBuf); + } + + @Test + void shouldThrowExceptionIfUsernameLengthExitsAllowedBounds() { + StringBuilder usernameBuilder = new StringBuilder(); + String usernamePart = + "𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎𠸏𠹷𠺝𠺢𠻗𠻹𠻺𠼭𠼮𠽌𠾴𠾼𠿪𡁜𡁯𡁵𡁶𡁻𡃁𡃉𡇙𢃇𢞵𢫕𢭃𢯊𢱑𢱕𢳂𢴈𢵌𢵧𢺳𣲷𤓓𤶸𤷪𥄫𦉘𦟌𦧲𦧺𧨾𨅝𨈇𨋢𨳊𨳍𨳒𩶘𠜎𠜱𠝹"; + for (int i = 0; i < 65535 / usernamePart.length(); i++) { + usernameBuilder.append(usernamePart); + } + String password = "tset1234"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeSimpleMetadata( + ByteBufAllocator.DEFAULT, + usernameBuilder.toString().toCharArray(), + password.toCharArray())) + .hasMessage( + "Username should be shorter than or equal to 65535 bytes length in UTF-8 encoding"); + } + + @Test + void shouldEncodeBearerMetadata() { + String testToken = TEST_BEARER_TOKEN; + + ByteBuf byteBuf = + AuthMetadataCodec.encodeBearerMetadata(ByteBufAllocator.DEFAULT, testToken.toCharArray()); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(testToken, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(testToken, byteBuf); + } + + private static void checkBearerAuthMetadataEncoding(String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat( + byteBuf.readUnsignedByte() & ~AuthMetadataCodec.STREAM_METADATA_KNOWN_MASK) + .isEqualTo(WellKnownAuthType.BEARER.getIdentifier()); + Assertions.assertThat(byteBuf.readSlice(byteBuf.capacity() - 1).toString(CharsetUtil.UTF_8)) + .isEqualTo(testToken); + } + + private static void checkBearerAuthMetadataEncodingUsingDecoders( + String testToken, ByteBuf byteBuf) { + Assertions.assertThat(byteBuf.capacity()) + .isEqualTo(testToken.getBytes(CharsetUtil.UTF_8).length + AUTH_TYPE_ID_LENGTH); + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(byteBuf)).isTrue(); + Assertions.assertThat(AuthMetadataCodec.readWellKnownAuthType(byteBuf)) + .isEqualTo(WellKnownAuthType.BEARER); + byteBuf.markReaderIndex(); + Assertions.assertThat(new String(AuthMetadataCodec.readBearerTokenAsCharArray(byteBuf))) + .isEqualTo(testToken); + byteBuf.resetReaderIndex(); + Assertions.assertThat( + AuthMetadataCodec.readPayload(byteBuf).toString(CharsetUtil.UTF_8).toString()) + .isEqualTo(testToken); + } + + @Test + void shouldEncodeCustomAuth() { + String payloadAsAText = "testsecuritybuffer"; + ByteBuf testSecurityPayload = + Unpooled.wrappedBuffer(payloadAsAText.getBytes(CharsetUtil.UTF_8)); + + String customAuthType = "myownauthtype"; + ByteBuf buffer = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload); + + checkCustomAuthMetadataEncoding(testSecurityPayload, customAuthType, buffer); + } + + private static void checkCustomAuthMetadataEncoding( + ByteBuf testSecurityPayload, String customAuthType, ByteBuf buffer) { + Assertions.assertThat(buffer.capacity()) + .isEqualTo(1 + customAuthType.length() + testSecurityPayload.capacity()); + Assertions.assertThat(buffer.readUnsignedByte()) + .isEqualTo((short) (customAuthType.length() - 1)); + Assertions.assertThat( + buffer.readCharSequence(customAuthType.length(), CharsetUtil.US_ASCII).toString()) + .isEqualTo(customAuthType); + Assertions.assertThat(buffer.readSlice(testSecurityPayload.capacity())) + .isEqualTo(testSecurityPayload); + + ReferenceCountUtil.release(buffer); + } + + @Test + void shouldThrowOnNonASCIIChars() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = "1234567#4? 𠜎𠜱𠝹𠱓𠱸𠲖𠳏𠳕𠴕𠵼𠵿𠸎"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage("custom auth type must be US_ASCII characters only"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + // 130 chars + String customAuthType = + "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldThrowOnOutOfAllowedSizeType1() { + ByteBuf testSecurityPayload = ByteBufAllocator.DEFAULT.buffer(); + String customAuthType = ""; + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, customAuthType, testSecurityPayload)) + .hasMessage( + "custom auth type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void shouldEncodeUsingWellKnownAuthType() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType1() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.SIMPLE, + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldEncodeUsingWellKnownAuthType2() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldThrowIfWellKnownAuthTypeIsUnsupportedOrUnknown() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, WellKnownAuthType.UNPARSEABLE_AUTH_TYPE, buffer)) + .hasMessage("only allowed AuthType should be used"); + + buffer.release(); + } + + @Test + void shouldCompressMetadata() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "simple", + ByteBufAllocator.DEFAULT.buffer().writeShort(1).writeByte('u').writeByte('p')); + + checkSimpleAuthMetadataEncoding("u", "p", 1, 1, byteBuf); + } + + @Test + void shouldCompressMetadata1() { + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, + "bearer", + Unpooled.copiedBuffer(TEST_BEARER_TOKEN, CharsetUtil.UTF_8)); + + byteBuf.markReaderIndex(); + checkBearerAuthMetadataEncoding(TEST_BEARER_TOKEN, byteBuf); + byteBuf.resetReaderIndex(); + checkBearerAuthMetadataEncodingUsingDecoders(TEST_BEARER_TOKEN, byteBuf); + } + + @Test + void shouldNotCompressMetadata() { + ByteBuf testMetadataPayload = + Unpooled.wrappedBuffer(TEST_BEARER_TOKEN.getBytes(CharsetUtil.UTF_8)); + String customAuthType = "testauthtype"; + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, customAuthType, testMetadataPayload); + + checkCustomAuthMetadataEncoding(testMetadataPayload, customAuthType, byteBuf); + } + + @Test + void shouldConfirmWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(metadata)).isTrue(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldConfirmGivenMetadataIsNotAWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple/afafgafadf", Unpooled.EMPTY_BUFFER); + + int initialReaderIndex = metadata.readerIndex(); + + Assertions.assertThat(AuthMetadataCodec.isWellKnownAuthType(metadata)).isFalse(); + Assertions.assertThat(metadata.readerIndex()).isEqualTo(initialReaderIndex); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadSimpleWellKnownAuthType() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "simple", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.SIMPLE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType1() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadataWithCompression( + ByteBufAllocator.DEFAULT, "bearer", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.BEARER; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldReadSimpleWellKnownAuthType2() { + ByteBuf metadata = + ByteBufAllocator.DEFAULT + .buffer() + .writeByte(3 | AuthMetadataCodec.STREAM_METADATA_KNOWN_MASK); + WellKnownAuthType expectedType = WellKnownAuthType.UNKNOWN_RESERVED_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength() { + ByteBuf metadata = ByteBufAllocator.DEFAULT.buffer().writeByte(3); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldNotReadSimpleWellKnownAuthTypeIfEncodedLength1() { + ByteBuf metadata = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, "testmetadataauthtype", Unpooled.EMPTY_BUFFER); + WellKnownAuthType expectedType = WellKnownAuthType.UNPARSEABLE_AUTH_TYPE; + checkDecodeWellKnowAuthTypeCorrectly(metadata, expectedType); + } + + @Test + void shouldThrowExceptionIsNotEnoughReadableBytes() { + Assertions.assertThatThrownBy( + () -> AuthMetadataCodec.readWellKnownAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode Well Know Auth type. Not enough readable bytes"); + } + + private static void checkDecodeWellKnowAuthTypeCorrectly( + ByteBuf metadata, WellKnownAuthType expectedType) { + int initialReaderIndex = metadata.readerIndex(); + + WellKnownAuthType wellKnownAuthType = AuthMetadataCodec.readWellKnownAuthType(metadata); + + Assertions.assertThat(wellKnownAuthType).isEqualTo(expectedType); + Assertions.assertThat(metadata.readerIndex()) + .isNotEqualTo(initialReaderIndex) + .isEqualTo(initialReaderIndex + 1); + + ReferenceCountUtil.release(metadata); + } + + @Test + void shouldReadCustomEncodedAuthType() { + String testAuthType = "TestAuthType"; + ByteBuf byteBuf = + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, testAuthType, Unpooled.EMPTY_BUFFER); + checkDecodeCustomAuthTypeCorrectly(testAuthType, byteBuf); + } + + @Test + void shouldThrowExceptionOnEmptyMetadata() { + Assertions.assertThatThrownBy(() -> AuthMetadataCodec.readCustomAuthType(Unpooled.EMPTY_BUFFER)) + .hasMessage("Unable to decode custom Auth type. Not enough readable bytes"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_wellknowninstead() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.readCustomAuthType( + AuthMetadataCodec.encodeMetadata( + ByteBufAllocator.DEFAULT, + WellKnownAuthType.BEARER, + Unpooled.copiedBuffer(new byte[] {'a', 'b'})))) + .hasMessage("Unable to decode custom Auth type. Incorrect auth type length"); + } + + @Test + void shouldThrowExceptionOnMalformedMetadata_length() { + Assertions.assertThatThrownBy( + () -> + AuthMetadataCodec.readCustomAuthType( + ByteBufAllocator.DEFAULT.buffer().writeByte(127).writeChar('a').writeChar('b'))) + .hasMessage("Unable to decode custom Auth type. Malformed length or auth type string"); + } + + private static void checkDecodeCustomAuthTypeCorrectly(String testAuthType, ByteBuf byteBuf) { + int initialReaderIndex = byteBuf.readerIndex(); + + Assertions.assertThat(AuthMetadataCodec.readCustomAuthType(byteBuf).toString()) + .isEqualTo(testAuthType); + Assertions.assertThat(byteBuf.readerIndex()) + .isEqualTo(initialReaderIndex + testAuthType.length() + 1); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java new file mode 100644 index 000000000..a4e8fb2d8 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataCodecTest.java @@ -0,0 +1,558 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeAndContentBuffersSlices; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeIdFromMimeBuffer; +import static io.rsocket.metadata.CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer; +import static org.assertj.core.api.Assertions.*; + +import io.netty.buffer.*; +import io.netty.util.CharsetUtil; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +class CompositeMetadataCodecTest { + + final LeaksTrackingByteBufAllocator testAllocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + @AfterEach + void tearDownAndCheckForLeaks() { + testAllocator.assertHasNoLeaks(); + } + + static String byteToBitsString(byte b) { + return String.format("%8s", Integer.toBinaryString(b & 0xFF)).replace(' ', '0'); + } + + static String toHeaderBits(ByteBuf encoded) { + encoded.markReaderIndex(); + byte headerByte = encoded.readByte(); + String byteAsString = byteToBitsString(headerByte); + encoded.resetReaderIndex(); + return byteAsString; + } + // ==== + + @Test + void customMimeHeaderLatin1_encodingFails() { + String mimeNotAscii = "mime/typé"; + + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void customMimeHeaderLength0_encodingFails() { + assertThatIllegalArgumentException() + .isThrownBy(() -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "", 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLength127() { + StringBuilder builder = new StringBuilder(127); + for (int i = 0; i < 127; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111110"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(127 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(127, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderLength128() { + StringBuilder builder = new StringBuilder(128); + for (int i = 0; i < 128; i++) { + builder.append('a'); + } + String mimeString = builder.toString(); + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("01111111"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(128 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(128, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderLength129_encodingFails() { + StringBuilder builder = new StringBuilder(129); + for (int i = 0; i < 129; i++) { + builder.append('a'); + } + + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, builder.toString(), 0)) + .withMessage( + "custom mime type must have a strictly positive length that fits on 7 unsigned bits, ie 1-128"); + } + + @Test + void customMimeHeaderLengthOne() { + String mimeString = "w"; + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()).as("mime length").isZero(); // encoded as actual length - 1 + + assertThat(header.readCharSequence(1, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderLengthTwo() { + String mimeString = "ww"; + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeString, 0); + + // remember actual length = encoded length + 1 + assertThat(toHeaderBits(encoded)).startsWith("0").isEqualTo("00000001"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isGreaterThan(1); + + assertThat((int) header.readByte()) + .as("mime length") + .isEqualTo(2 - 1); // encoded as actual length - 1 + + assertThat(header.readCharSequence(2, CharsetUtil.US_ASCII)) + .as("mime string") + .hasToString(mimeString); + + header.resetReaderIndex(); + assertThat(CompositeMetadataCodec.decodeMimeTypeFromMimeBuffer(header)) + .as("decoded mime string") + .hasToString(mimeString); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void customMimeHeaderUtf8_encodingFails() { + String mimeNotAscii = + "mime/tyࠒe"; // this is the SAMARITAN LETTER QUF u+0812 represented on 3 bytes + assertThatIllegalArgumentException() + .isThrownBy( + () -> CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mimeNotAscii, 0)) + .withMessage("custom mime type must be US_ASCII characters only"); + } + + @Test + void decodeEntryAtEndOfBuffer() { + ByteBuf fakeEntry = Unpooled.buffer(); + + assertThatIllegalArgumentException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeAndContentBuffersSlices(fakeEntry, 0, false)); + } + + @Test + void decodeIdMinusTwoWhenMoreThanOneByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(2); + fakeIdBuffer.writeInt(200); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeIdMinusTwoWhenZeroByte() { + ByteBuf fakeIdBuffer = Unpooled.buffer(0); + + assertThat(decodeMimeIdFromMimeBuffer(fakeIdBuffer)) + .isEqualTo((WellKnownMimeType.UNPARSEABLE_MIME_TYPE.getIdentifier())); + } + + @Test + void decodeStringNullIfLengthOne() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(1); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeStringNullIfLengthZero() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + + assertThatIllegalStateException() + .isThrownBy(() -> decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)); + } + + @Test + void decodeTypeSkipsFirstByte() { + ByteBuf fakeTypeBuffer = Unpooled.buffer(2); + fakeTypeBuffer.writeByte(128); + fakeTypeBuffer.writeCharSequence("example", CharsetUtil.US_ASCII); + + assertThat(decodeMimeTypeFromMimeBuffer(fakeTypeBuffer)).hasToString("example"); + } + + @Test + void encodeMetadataCustomTypeDelegates() { + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, "foo", 2); + + CompositeByteBuf test = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + test, testAllocator, "foo", ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); + } + + @Test + void encodeMetadataKnownTypeDelegates() { + ByteBuf expected = + CompositeMetadataCodec.encodeMetadataHeader( + testAllocator, WellKnownMimeType.APPLICATION_OCTET_STREAM.getIdentifier(), 2); + + CompositeByteBuf test = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + test, + testAllocator, + WellKnownMimeType.APPLICATION_OCTET_STREAM, + ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); + } + + @Test + void encodeMetadataReservedTypeDelegates() { + ByteBuf expected = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, (byte) 120, 2); + + CompositeByteBuf test = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + test, testAllocator, (byte) 120, ByteBufUtils.getRandomByteBuf(2)); + + assertThat((Iterable) test).hasSize(2).first().isEqualTo(expected); + test.release(); + expected.release(); + } + + @Test + void encodeTryCompressWithCompressableType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + target, testAllocator, WellKnownMimeType.APPLICATION_AVRO.getString(), metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 3 + 2").isEqualTo(6); + target.release(); + } + + @Test + void encodeTryCompressWithCustomType() { + ByteBuf metadata = ByteBufUtils.getRandomByteBuf(2); + CompositeByteBuf target = testAllocator.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + target, testAllocator, "custom/example", metadata); + + assertThat(target.readableBytes()).as("readableBytes 1 + 14 + 3 + 2").isEqualTo(20); + target.release(); + } + + @Test + void hasEntry() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + + CompositeByteBuf buffer = + testAllocator + .compositeBuffer() + .addComponent( + true, + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0)) + .addComponent( + true, + CompositeMetadataCodec.encodeMetadataHeader( + testAllocator, mime.getIdentifier(), 0)); + + assertThat(CompositeMetadataCodec.hasEntry(buffer, 0)).isTrue(); + assertThat(CompositeMetadataCodec.hasEntry(buffer, 4)).isTrue(); + assertThat(CompositeMetadataCodec.hasEntry(buffer, 8)).isFalse(); + buffer.release(); + } + + @Test + void isWellKnownMimeType() { + ByteBuf wellKnown = Unpooled.buffer().writeByte(0); + assertThat(CompositeMetadataCodec.isWellKnownMimeType(wellKnown)).isTrue(); + + ByteBuf explicit = Unpooled.buffer().writeByte(2).writeChar('a'); + assertThat(CompositeMetadataCodec.isWellKnownMimeType(explicit)).isFalse(); + } + + @Test + void knownMimeHeader120_reserved() { + byte mime = (byte) 120; + ByteBuf encoded = CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime, 0); + + assertThat(mime) + .as("smoke test RESERVED_120 unsigned 7 bits representation") + .isEqualTo((byte) 0b01111000); + + assertThat(toHeaderBits(encoded)).startsWith("1").isEqualTo("11111000"); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)).as("decoded mime id").isEqualTo(mime); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void knownMimeHeader127_compositeMetadata() { + WellKnownMimeType mime = WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA; + assertThat(mime.getIdentifier()) + .as("smoke test COMPOSITE unsigned 7 bits representation") + .isEqualTo((byte) 127) + .isEqualTo((byte) 0b01111111); + ByteBuf encoded = + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("11111111") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("11111111"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void knownMimeHeaderZero_avro() { + WellKnownMimeType mime = WellKnownMimeType.APPLICATION_AVRO; + assertThat(mime.getIdentifier()) + .as("smoke test AVRO unsigned 7 bits representation") + .isEqualTo((byte) 0) + .isEqualTo((byte) 0b00000000); + ByteBuf encoded = + CompositeMetadataCodec.encodeMetadataHeader(testAllocator, mime.getIdentifier(), 0); + + assertThat(toHeaderBits(encoded)) + .startsWith("1") + .isEqualTo("10000000") + .isEqualTo(byteToBitsString(mime.getIdentifier()).replaceFirst("0", "1")); + + final ByteBuf[] byteBufs = decodeMimeAndContentBuffersSlices(encoded, 0, false); + assertThat(byteBufs).hasSize(2).doesNotContainNull(); + + ByteBuf header = byteBufs[0]; + ByteBuf content = byteBufs[1]; + header.markReaderIndex(); + + assertThat(header.readableBytes()).as("metadata header size").isOne(); + + assertThat(byteToBitsString(header.readByte())) + .as("header bit representation") + .isEqualTo("10000000"); + + header.resetReaderIndex(); + assertThat(decodeMimeIdFromMimeBuffer(header)) + .as("decoded mime id") + .isEqualTo(mime.getIdentifier()); + + assertThat(content.readableBytes()).as("no metadata content").isZero(); + encoded.release(); + } + + @Test + void encodeCustomHeaderAsciiCheckSkipsFirstByte() { + final ByteBuf badBuf = Unpooled.copiedBuffer("é00000000000", CharsetUtil.UTF_8); + badBuf.writerIndex(0); + assertThat(badBuf.readerIndex()).isZero(); + + ByteBufAllocator allocator = + new AbstractByteBufAllocator() { + @Override + public boolean isDirectBufferPooled() { + return false; + } + + @Override + protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) { + return badBuf; + } + + @Override + protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { + return badBuf; + } + }; + + assertThatCode(() -> CompositeMetadataCodec.encodeMetadataHeader(allocator, "custom/type", 0)) + .doesNotThrowAnyException(); + + assertThat(badBuf.readByte()).isEqualTo((byte) 10); + assertThat(badBuf.readCharSequence(11, CharsetUtil.UTF_8)).hasToString("custom/type"); + assertThat(badBuf.readUnsignedMedium()).isEqualTo(0); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java new file mode 100644 index 000000000..0b81ab4b0 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/CompositeMetadataTest.java @@ -0,0 +1,178 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadata.ReservedMimeTypeEntry; +import io.rsocket.metadata.CompositeMetadata.WellKnownMimeTypeEntry; +import io.rsocket.test.util.ByteBufUtils; +import io.rsocket.util.NumberUtils; +import java.util.Iterator; +import java.util.Spliterator; +import org.junit.jupiter.api.Test; + +class CompositeMetadataTest { + + @Test + void decodeEntryHasNoContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(0); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryOnDoneBufferThrowsIllegalArgument() { + ByteBuf fakeBuffer = ByteBufUtils.getRandomByteBuf(0); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeBuffer, false); + + assertThatIllegalArgumentException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("entry index 0 is larger than buffer size"); + } + + @Test + void decodeEntryTooShortForContentLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(1); + fakeEntry.writeCharSequence("w", CharsetUtil.US_ASCII); + NumberUtils.encodeUnsignedMedium(fakeEntry, 456); + fakeEntry.writeChar('w'); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeEntryTooShortForMimeLength() { + ByteBuf fakeEntry = Unpooled.buffer(); + fakeEntry.writeByte(120); + CompositeMetadata compositeMetadata = new CompositeMetadata(fakeEntry, false); + + assertThatIllegalStateException() + .isThrownBy(() -> compositeMetadata.iterator().next()) + .withMessage("metadata is malformed"); + } + + @Test + void decodeThreeEntries() { + // metadata 1: well known + WellKnownMimeType mimeType1 = WellKnownMimeType.APPLICATION_PDF; + ByteBuf metadata1 = Unpooled.buffer(); + metadata1.writeCharSequence("abcdefghijkl", CharsetUtil.UTF_8); + + // metadata 2: custom + String mimeType2 = "application/custom"; + ByteBuf metadata2 = Unpooled.buffer(); + metadata2.writeChar('E'); + metadata2.writeChar('∑'); + metadata2.writeChar('é'); + metadata2.writeBoolean(true); + metadata2.writeChar('W'); + + // metadata 3: reserved but unknown + byte reserved = 120; + assertThat(WellKnownMimeType.fromIdentifier(reserved)) + .as("ensure UNKNOWN RESERVED used in test") + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + ByteBuf metadata3 = Unpooled.buffer(); + metadata3.writeByte(88); + + CompositeByteBuf compositeMetadataBuffer = ByteBufAllocator.DEFAULT.compositeBuffer(); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType1, metadata1); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, mimeType2, metadata2); + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadataBuffer, ByteBufAllocator.DEFAULT, reserved, metadata3); + + Iterator iterator = new CompositeMetadata(compositeMetadataBuffer, true).iterator(); + + assertThat(iterator.next()) + .as("entry1") + .isNotNull() + .satisfies( + e -> + assertThat(e.getMimeType()).as("entry1 mime type").isEqualTo(mimeType1.getString())) + .satisfies( + e -> + assertThat(((WellKnownMimeTypeEntry) e).getType()) + .as("entry1 mime id") + .isEqualTo(WellKnownMimeType.APPLICATION_PDF)) + .satisfies( + e -> + assertThat(e.getContent().toString(CharsetUtil.UTF_8)) + .as("entry1 decoded") + .isEqualTo("abcdefghijkl")); + + assertThat(iterator.next()) + .as("entry2") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry2 mime type").isEqualTo(mimeType2)) + .satisfies( + e -> assertThat(e.getContent()).as("entry2 decoded").isEqualByComparingTo(metadata2)); + + assertThat(iterator.next()) + .as("entry3") + .isNotNull() + .satisfies(e -> assertThat(e.getMimeType()).as("entry3 mime type").isNull()) + .satisfies( + e -> + assertThat(((ReservedMimeTypeEntry) e).getType()) + .as("entry3 mime id") + .isEqualTo(reserved)) + .satisfies( + e -> assertThat(e.getContent()).as("entry3 decoded").isEqualByComparingTo(metadata3)); + + assertThat(iterator.hasNext()).as("has no more than 3 entries").isFalse(); + } + + @Test + void streamIsNotParallel() { + final CompositeMetadata metadata = + new CompositeMetadata(ByteBufUtils.getRandomByteBuf(5), false); + + assertThat(metadata.stream().isParallel()).as("isParallel").isFalse(); + } + + @Test + void streamSpliteratorCharacteristics() { + final CompositeMetadata metadata = + new CompositeMetadata(ByteBufUtils.getRandomByteBuf(5), false); + + assertThat(metadata.stream().spliterator()) + .matches(s -> s.hasCharacteristics(Spliterator.ORDERED), "ORDERED") + .matches(s -> s.hasCharacteristics(Spliterator.DISTINCT), "DISTINCT") + .matches(s -> s.hasCharacteristics(Spliterator.NONNULL), "NONNULL") + .matches(s -> !s.hasCharacteristics(Spliterator.SIZED), "not SIZED"); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java new file mode 100644 index 000000000..5c8d40306 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/MimeTypeMetadataCodecTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import java.util.List; +import org.assertj.core.util.Lists; +import org.junit.jupiter.api.Test; + +/** Unit tests for {@link MimeTypeMetadataCodec}. */ +public class MimeTypeMetadataCodecTest { + + @Test + public void wellKnownMimeType() { + WellKnownMimeType mimeType = WellKnownMimeType.APPLICATION_HESSIAN; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(WellKnownMimeType.fromString(mimeTypes.get(0))).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void customMimeType() { + String mimeType = "aaa/bb"; + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeType); + try { + List mimeTypes = MimeTypeMetadataCodec.decode(byteBuf); + + assertThat(mimeTypes.size()).isEqualTo(1); + assertThat(mimeTypes.get(0)).isEqualTo(mimeType); + } finally { + byteBuf.release(); + } + } + + @Test + public void multipleMimeTypes() { + List mimeTypes = Lists.newArrayList("aaa/bbb", "application/x-hessian"); + ByteBuf byteBuf = MimeTypeMetadataCodec.encode(ByteBufAllocator.DEFAULT, mimeTypes); + + try { + assertThat(MimeTypeMetadataCodec.decode(byteBuf)).isEqualTo(mimeTypes); + } finally { + byteBuf.release(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java new file mode 100644 index 000000000..b65ffafee --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TaggingMetadataTest.java @@ -0,0 +1,47 @@ +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBufAllocator; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Test; + +/** + * Tagging metadata test + * + * @author linux_china + */ +public class TaggingMetadataTest { + private ByteBufAllocator byteBufAllocator = ByteBufAllocator.DEFAULT; + + @Test + public void testParseTags() { + List tags = + Arrays.asList( + "ws://localhost:8080/rsocket", String.join("", Collections.nCopies(129, "x"))); + TaggingMetadata taggingMetadata = + TaggingMetadataCodec.createTaggingMetadata( + byteBufAllocator, "message/x.rsocket.routing.v0", tags); + TaggingMetadata taggingMetadataCopy = + new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); + assertThat(tags) + .containsExactlyElementsOf(taggingMetadataCopy.stream().collect(Collectors.toList())); + } + + @Test + public void testEmptyTagAndOverLengthTag() { + List tags = + Arrays.asList( + "ws://localhost:8080/rsocket", "", String.join("", Collections.nCopies(256, "x"))); + TaggingMetadata taggingMetadata = + TaggingMetadataCodec.createTaggingMetadata( + byteBufAllocator, "message/x.rsocket.routing.v0", tags); + TaggingMetadata taggingMetadataCopy = + new TaggingMetadata("message/x.rsocket.routing.v0", taggingMetadata.getContent()); + assertThat(tags.subList(0, 1)) + .containsExactlyElementsOf(taggingMetadataCopy.stream().collect(Collectors.toList())); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java new file mode 100644 index 000000000..cb8478c13 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/TracingMetadataCodecTest.java @@ -0,0 +1,209 @@ +package io.rsocket.metadata; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCounted; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +public class TracingMetadataCodecTest { + + private static Stream flags() { + return Stream.of(TracingMetadataCodec.Flags.values()); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeEmptyTrace(TracingMetadataCodec.Flags expectedFlag) { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encodeEmpty(allocator, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(TracingMetadata::isEmpty) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode64(allocator, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace64(TracingMetadataCodec.Flags expectedFlag) { + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = TracingMetadataCodec.encode64(allocator, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == 0) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128WithParent(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + long parentId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128( + allocator, traceIdHigh, traceId, spanId, parentId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> tm.hasParent()) + .matches(tm -> tm.parentId() == parentId) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } + + @ParameterizedTest + @MethodSource("flags") + public void shouldEncodeTrace128(TracingMetadataCodec.Flags expectedFlag) { + long traceIdHighLocal; + do { + traceIdHighLocal = ThreadLocalRandom.current().nextLong(); + + } while (traceIdHighLocal == 0); + long traceIdHigh = traceIdHighLocal; + long traceId = ThreadLocalRandom.current().nextLong(); + long spanId = ThreadLocalRandom.current().nextLong(); + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + ByteBuf byteBuf = + TracingMetadataCodec.encode128(allocator, traceIdHigh, traceId, spanId, expectedFlag); + + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(byteBuf); + + Assertions.assertThat(tracingMetadata) + .matches(metadata -> !metadata.isEmpty()) + .matches(tm -> tm.traceIdHigh() == traceIdHigh) + .matches(tm -> tm.traceId() == traceId) + .matches(tm -> tm.spanId() == spanId) + .matches(tm -> !tm.hasParent()) + .matches(tm -> tm.parentId() == 0) + .matches( + tm -> { + switch (expectedFlag) { + case UNDECIDED: + return !tm.isDecided(); + case NOT_SAMPLE: + return tm.isDecided() && !tm.isSampled(); + case SAMPLE: + return tm.isDecided() && tm.isSampled(); + case DEBUG: + return tm.isDecided() && tm.isDebug(); + } + return false; + }); + Assertions.assertThat(byteBuf).matches(ReferenceCounted::release); + allocator.assertHasNoLeaks(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java new file mode 100644 index 000000000..316aaf091 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/metadata/WellKnownMimeTypeTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.metadata; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +class WellKnownMimeTypeTest { + + @Test + void fromIdentifierGreaterThan127() { + assertThat(WellKnownMimeType.fromIdentifier(128)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierMatchFromMimeType() { + for (WellKnownMimeType mimeType : WellKnownMimeType.values()) { + if (mimeType == WellKnownMimeType.UNPARSEABLE_MIME_TYPE + || mimeType == WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE) { + continue; + } + assertThat(WellKnownMimeType.fromString(mimeType.toString())) + .as("mimeType string for " + mimeType.name()) + .isSameAs(mimeType); + + assertThat(WellKnownMimeType.fromIdentifier(mimeType.getIdentifier())) + .as("mimeType ID for " + mimeType.name()) + .isSameAs(mimeType); + } + } + + @Test + void fromIdentifierNegative() { + assertThat(WellKnownMimeType.fromIdentifier(-1)) + .isSameAs(WellKnownMimeType.fromIdentifier(-2)) + .isSameAs(WellKnownMimeType.fromIdentifier(-12)) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromIdentifierReserved() { + assertThat(WellKnownMimeType.fromIdentifier(120)) + .isSameAs(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE); + } + + @Test + void fromStringUnknown() { + assertThat(WellKnownMimeType.fromString("foo/bar")) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } + + @Test + void fromStringUnknownReservedStillReturnsUnparseable() { + assertThat( + WellKnownMimeType.fromString(WellKnownMimeType.UNKNOWN_RESERVED_MIME_TYPE.getString())) + .isSameAs(WellKnownMimeType.UNPARSEABLE_MIME_TYPE); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java new file mode 100644 index 000000000..9a19050f9 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/RequestInterceptorTest.java @@ -0,0 +1,790 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.FrameType; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; +import reactor.util.annotation.Nullable; + +public class RequestInterceptorTest { + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheClientResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .bindNow(LocalServerTransport.create("test")); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerRequesterSide(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .interceptors( + ir -> + ir.forRequestsInResponder( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void interceptorShouldBeInstalledProperlyOnTheServerResponderSide(boolean errorOutcome) + throws InterruptedException { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + CountDownLatch latch = new CountDownLatch(1); + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final Closeable closeable = + RSocketServer.create( + (setup, rSocket) -> + Mono.just(new RSocket() {}) + .doAfterTerminate( + () -> { + new Thread( + () -> { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel( + Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + latch.countDown(); + }) + .start(); + })) + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .bindNow(LocalServerTransport.create("test")); + final RSocket rSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + Assertions.assertThat(latch.await(1, TimeUnit.SECONDS)).isTrue(); + + testRequestInterceptor + .expectOnStart(2, FrameType.REQUEST_FNF) + .expectOnComplete(2) + .expectOnStart(4, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 4) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(6, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 6) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(8, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 8) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @Test + void ensuresExceptionInTheInterceptorIsHandledProperly() { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnCancel"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + StepVerifier.create(rSocket.fireAndForget(DefaultPayload.create("test"))) + .expectSubscription() + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestResponse(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestStream(DefaultPayload.create("test"))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + + StepVerifier.create(rSocket.requestChannel(Flux.just(DefaultPayload.create("test")))) + .expectSubscription() + .expectNextCount(1) + .expectComplete() + .verify(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldSupportMultipleInterceptors(boolean errorOutcome) { + final LeaksTrackingByteBufAllocator byteBufAllocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "test"); + + final Closeable closeable = + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return errorOutcome + ? Mono.error(new RuntimeException("test")) + : Mono.just(payload); + } + + @Override + public Flux requestStream(Payload payload) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.just(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorOutcome + ? Flux.error(new RuntimeException("test")) + : Flux.from(payloads); + } + })) + .bindNow(LocalServerTransport.create("test")); + + final RequestInterceptor testRequestInterceptor1 = + new RequestInterceptor() { + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnStart"); + } + + @Override + public void onTerminate( + int streamId, FrameType requestType, @Nullable Throwable terminalSignal) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + throw new RuntimeException("testOnTerminate"); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + throw new RuntimeException("testOnReject"); + } + + @Override + public void dispose() {} + }; + final TestRequestInterceptor testRequestInterceptor = new TestRequestInterceptor(); + final TestRequestInterceptor testRequestInterceptor2 = new TestRequestInterceptor(); + final RSocket rSocket = + RSocketConnector.create() + .interceptors( + ir -> + ir.forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor1) + .forRequestsInRequester( + (Function) + (__) -> testRequestInterceptor2)) + .connect(LocalClientTransport.create("test", byteBufAllocator)) + .block(); + + try { + rSocket + .fireAndForget(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestResponse(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .block(); + + rSocket + .requestStream(DefaultPayload.create("test")) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + rSocket + .requestChannel(Flux.just(DefaultPayload.create("test"))) + .onErrorResume(__ -> Mono.empty()) + .blockLast(); + + testRequestInterceptor + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + + testRequestInterceptor2 + .expectOnStart(1, FrameType.REQUEST_FNF) + .expectOnComplete(1) + .expectOnStart(3, FrameType.REQUEST_RESPONSE) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 3) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(5, FrameType.REQUEST_STREAM) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 5) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectOnStart(7, FrameType.REQUEST_CHANNEL) + .assertNext( + e -> + Assertions.assertThat(e) + .hasFieldOrPropertyWithValue("streamId", 7) + .hasFieldOrPropertyWithValue( + "eventType", + errorOutcome + ? TestRequestInterceptor.EventType.ON_ERROR + : TestRequestInterceptor.EventType.ON_COMPLETE)) + .expectNothing(); + } finally { + rSocket.dispose(); + closeable.dispose(); + byteBufAllocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java new file mode 100644 index 000000000..8261b3f49 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/plugins/TestRequestInterceptor.java @@ -0,0 +1,142 @@ +package io.rsocket.plugins; + +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.internal.jctools.queues.MpscUnboundedArrayQueue; +import java.util.Queue; +import java.util.function.Consumer; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Condition; +import reactor.util.annotation.Nullable; + +public class TestRequestInterceptor implements RequestInterceptor { + + final Queue events = new MpscUnboundedArrayQueue<>(128); + + @Override + public void dispose() {} + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_START, streamId, requestType, null)); + } + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + events.add( + new Event( + t == null ? EventType.ON_COMPLETE : EventType.ON_ERROR, streamId, requestType, t)); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + events.add(new Event(EventType.ON_CANCEL, streamId, requestType, null)); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) { + events.add(new Event(EventType.ON_REJECT, -1, requestType, rejectionReason)); + } + + public TestRequestInterceptor expectOnStart(int streamId, FrameType requestType) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_START) + .hasFieldOrPropertyWithValue("streamId", streamId) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectOnComplete(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_COMPLETE) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnError(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_ERROR) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor expectOnCancel(int streamId) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_CANCEL) + .hasFieldOrPropertyWithValue("streamId", streamId); + + return this; + } + + public TestRequestInterceptor assertNext(Consumer consumer) { + final Event event = events.poll(); + Assertions.assertThat(event).isNotNull(); + + consumer.accept(event); + + return this; + } + + public TestRequestInterceptor expectOnReject(FrameType requestType, Throwable rejectionReason) { + final Event event = events.poll(); + + Assertions.assertThat(event) + .hasFieldOrPropertyWithValue("eventType", EventType.ON_REJECT) + .has( + new Condition<>( + e -> { + Assertions.assertThat(e.error) + .isExactlyInstanceOf(rejectionReason.getClass()) + .hasMessage(rejectionReason.getMessage()) + .hasCause(rejectionReason.getCause()); + return true; + }, + "Has rejection reason which matches to %s", + rejectionReason)) + .hasFieldOrPropertyWithValue("requestType", requestType); + + return this; + } + + public TestRequestInterceptor expectNothing() { + final Event event = events.poll(); + + Assertions.assertThat(event).isNull(); + + return this; + } + + public static final class Event { + public final EventType eventType; + public final int streamId; + public final FrameType requestType; + public final Throwable error; + + Event(EventType eventType, int streamId, FrameType requestType, Throwable error) { + this.eventType = eventType; + this.streamId = streamId; + this.requestType = requestType; + this.error = error; + } + } + + public enum EventType { + ON_START, + ON_COMPLETE, + ON_ERROR, + ON_CANCEL, + ON_REJECT + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java new file mode 100644 index 000000000..8229bf42b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ClientRSocketSessionTest.java @@ -0,0 +1,470 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.exceptions.ConnectionCloseException; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeOkFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import io.rsocket.test.util.TestDuplexConnection; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; +import reactor.util.function.Tuples; +import reactor.util.retry.Retry; + +public class ClientRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new ConnectionCloseException("some message"))); + // connection should be closed because of the wrong first frame + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout is still in progress + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void sessionTerminationOnWrongFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME_OK frame + transport + .testConnection() + .addToReceivedBuffer(ResumeOkFrameCodec.encode(transport.alloc(), 0)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // Send KEEPALIVE frame as a first frame + transport + .testConnection() + .addToReceivedBuffer( + KeepAliveFrameCodec.encode(transport.alloc(), false, 0, Unpooled.EMPTY_BUFFER)); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(30)); + + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(transport.testConnection().isDisposed()).isTrue(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectErrorMessage("RESUME_OK frame must be received before any others") + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldErrorWithNoRetriesOnErrorFrameTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time so new connection is received + virtualTimeScheduler.advanceTimeBy(Duration.ofMillis(1)); + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME) + .matches(ReferenceCounted::release); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send REJECTED_RESUME_ERROR frame + transport + .testConnection() + .addToReceivedBuffer( + ErrorFrameCodec.encode( + transport.alloc(), 0, new RejectedResumeException("failed resumption"))); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + resumableDuplexConnection + .onClose() + .as(StepVerifier::create) + .expectError(RejectedResumeException.class) + .verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ClientRSocketSession session = + new ClientRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.connect().delaySubscription(Duration.ofMillis(1)), + c -> { + AtomicBoolean firstHandled = new AtomicBoolean(); + return ((TestDuplexConnection) c) + .receive() + .next() + .doOnNext(__ -> firstHandled.set(true)) + .doOnCancel( + () -> { + if (firstHandled.compareAndSet(false, true)) { + c.dispose(); + } + }) + .map(b -> Tuples.of(b, c)); + }, + framesStore, + Duration.ofMinutes(1), + Retry.indefinitely(), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java new file mode 100644 index 000000000..bba40d674 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/InMemoryResumeStoreTest.java @@ -0,0 +1,547 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.RaceTestConstants; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.internal.subscriber.AssertSubscriber; +import java.util.Arrays; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import reactor.core.Disposable; +import reactor.core.publisher.Hooks; +import reactor.test.util.RaceTestUtils; + +public class InMemoryResumeStoreTest { + + @Test + void saveNonResumableFrame() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeConnectionFrame(10); + final ByteBuf frame2 = fakeConnectionFrame(35); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + } + + @Test + void saveWithoutTailRemoval() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame = fakeResumableFrame(10); + + sender.onNext(frame); + + assertThat(store.cachedFrames.size()).isEqualTo(1); + assertThat(store.cacheSize).isEqualTo(frame.readableBytes()); + assertThat(store.firstAvailableFramePosition).isZero(); + + assertSubscriber.assertValueCount(1).values().forEach(ByteBuf::release); + + assertThat(frame.refCnt()).isOne(); + } + + @Test + void saveRemoveOneFromTail() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(20); + final ByteBuf frame2 = fakeResumableFrame(10); + + sender.onNext(frame1); + sender.onNext(frame2); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame2.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(frame1.readableBytes()); + + assertSubscriber.assertValueCount(2).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isOne(); + } + + @Test + void saveRemoveTwoFromTail() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(20); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); + } + + @Test + void saveBiggerThanStore() { + final InMemoryResumableFramesStore store = inMemoryStore(25); + final UnboundedProcessor sender = new UnboundedProcessor(); + + store.saveFrames(sender).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + sender.onNext(frame1); + sender.onNext(frame2); + sender.onNext(frame3); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2, frame3)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @Test + void releaseFrames() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + + store.releaseFrames(20); + + assertThat(store.cachedFrames.size()).isOne(); + assertThat(store.cacheSize).isEqualTo(frame3.readableBytes()); + assertThat(store.firstAvailableFramePosition).isEqualTo(size(frame1, frame2)); + + assertSubscriber.assertValueCount(3).values().forEach(ByteBuf::release); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isOne(); + } + + @Test + void receiveImpliedPosition() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + ByteBuf frame1 = fakeResumableFrame(10); + ByteBuf frame2 = fakeResumableFrame(30); + + store.resumableFrameReceived(frame1); + store.resumableFrameReceived(frame2); + + assertThat(store.frameImpliedPosition()).isEqualTo(size(frame1, frame2)); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void ensuresCleansOnTerminal(boolean hasSubscriber) { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final AssertSubscriber assertSubscriber = + hasSubscriber ? store.resumeStream().subscribeWith(AssertSubscriber.create()) : null; + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + if (hasSubscriber) { + assertSubscriber.assertValueCount(3).assertTerminated().values().forEach(ByteBuf::release); + } + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @Test + void ensuresCleansOnTerminalLateSubscriber() { + final InMemoryResumableFramesStore store = inMemoryStore(100); + + final UnboundedProcessor producer = new UnboundedProcessor(); + store.saveFrames(producer).subscribe(); + + final ByteBuf frame1 = fakeResumableFrame(10); + final ByteBuf frame2 = fakeResumableFrame(10); + final ByteBuf frame3 = fakeResumableFrame(30); + + producer.onNext(frame1); + producer.onNext(frame2); + producer.onNext(frame3); + producer.onComplete(); + + assertThat(store.cachedFrames.size()).isZero(); + assertThat(store.cacheSize).isZero(); + + assertThat(producer.isDisposed()).isTrue(); + + final AssertSubscriber assertSubscriber = + store.resumeStream().subscribeWith(AssertSubscriber.create()); + assertSubscriber.assertTerminated(); + + assertThat(frame1.refCnt()).isZero(); + assertThat(frame2.refCnt()).isZero(); + assertThat(frame3.refCnt()).isZero(); + } + + @ParameterizedTest(name = "Sending vs Reconnect Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void sendingVsReconnectRaceTest(boolean withLateSubscriber) { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + final BlockingQueue receivedFrames = new ArrayBlockingQueue<>(10); + final AtomicInteger receivedPosition = new AtomicInteger(); + + store.saveFrames(frames).subscribe(); + + final Consumer consumer = + f -> { + if (ResumableDuplexConnection.isResumableFrame(f)) { + receivedPosition.addAndGet(f.readableBytes()); + receivedFrames.offer(f); + return; + } + f.release(); + }; + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber ? null : store.resumeStream().subscribe(consumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer)); + } + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + + // disconnect + disposableReference.get().dispose(); + + while (InMemoryResumableFramesStore.isWorkInProgress(store.state)) { + // ignore + } + + // mimic RESUME_OK frame received + store.releaseFrames(receivedPosition.get()); + disposableReference.set(store.resumeStream().subscribe(consumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + store.releaseFrames(receivedFrames.stream().mapToInt(ByteBuf::readableBytes).sum()); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + + assertThat(receivedFrames) + .hasSize(5) + .containsSequence(byteBuf1, byteBuf2, byteBuf3, byteBuf4, byteBuf5); + receivedFrames.forEach(ReferenceCounted::release); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } + + @ParameterizedTest( + name = "Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValue(1); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + @ParameterizedTest( + name = + "Dispose vs Sending vs Reconnect with incorrect position Race Test. WithLateSubscriber[{0}]") + @ValueSource(booleans = {true, false}) + void incorrectReleaseFramesWithOnNextWithDisposeRaceTest(boolean withLateSubscriber) { + Hooks.onErrorDropped(t -> {}); + try { + for (int i = 0; i < RaceTestConstants.REPEATS; i++) { + final InMemoryResumableFramesStore store = inMemoryStore(Integer.MAX_VALUE); + final UnboundedProcessor frames = new UnboundedProcessor(); + + store.saveFrames(frames).subscribe(); + + final AtomicInteger terminationCnt = new AtomicInteger(); + final Consumer consumer = ReferenceCounted::release; + final Consumer errorConsumer = __ -> terminationCnt.incrementAndGet(); + final AtomicReference disposableReference = + new AtomicReference<>( + withLateSubscriber + ? null + : store.resumeStream().subscribe(consumer, errorConsumer)); + + final ByteBuf byteBuf1 = fakeResumableFrame(5); + final ByteBuf byteBuf11 = fakeConnectionFrame(5); + final ByteBuf byteBuf2 = fakeResumableFrame(6); + final ByteBuf byteBuf21 = fakeConnectionFrame(5); + final ByteBuf byteBuf3 = fakeResumableFrame(7); + final ByteBuf byteBuf31 = fakeConnectionFrame(5); + final ByteBuf byteBuf4 = fakeResumableFrame(8); + final ByteBuf byteBuf41 = fakeConnectionFrame(5); + final ByteBuf byteBuf5 = fakeResumableFrame(25); + final ByteBuf byteBuf51 = fakeConnectionFrame(35); + + RaceTestUtils.race( + () -> { + if (withLateSubscriber) { + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + } + // disconnect + disposableReference.get().dispose(); + + // mimic RESUME_OK frame received but with incorrect position + store.releaseFrames(25); + disposableReference.set(store.resumeStream().subscribe(consumer, errorConsumer)); + }, + () -> { + frames.onNext(byteBuf1); + frames.onNextPrioritized(byteBuf11); + frames.onNext(byteBuf2); + frames.onNext(byteBuf3); + frames.onNextPrioritized(byteBuf31); + frames.onNext(byteBuf4); + frames.onNext(byteBuf5); + }, + () -> { + frames.onNextPrioritized(byteBuf21); + frames.onNextPrioritized(byteBuf41); + frames.onNextPrioritized(byteBuf51); + }, + store::dispose); + + assertThat(store.cacheSize).isZero(); + assertThat(store.cachedFrames).isEmpty(); + assertThat(disposableReference.get().isDisposed()).isTrue(); + assertThat(terminationCnt).hasValueGreaterThanOrEqualTo(1).hasValueLessThanOrEqualTo(2); + + assertThat(byteBuf1.refCnt()).isZero(); + assertThat(byteBuf11.refCnt()).isZero(); + assertThat(byteBuf2.refCnt()).isZero(); + assertThat(byteBuf21.refCnt()).isZero(); + assertThat(byteBuf3.refCnt()).isZero(); + assertThat(byteBuf31.refCnt()).isZero(); + assertThat(byteBuf4.refCnt()).isZero(); + assertThat(byteBuf41.refCnt()).isZero(); + assertThat(byteBuf5.refCnt()).isZero(); + assertThat(byteBuf51.refCnt()).isZero(); + } + } finally { + Hooks.resetOnErrorDropped(); + } + } + + private int size(ByteBuf... byteBufs) { + return Arrays.stream(byteBufs).mapToInt(ByteBuf::readableBytes).sum(); + } + + private static InMemoryResumableFramesStore inMemoryStore(int size) { + return new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, size); + } + + private static ByteBuf fakeResumableFrame(int size) { + byte[] bytes = new byte[size]; + Arrays.fill(bytes, (byte) 7); + return Unpooled.wrappedBuffer(bytes); + } + + private static ByteBuf fakeConnectionFrame(int size) { + byte[] bytes = new byte[size]; + Arrays.fill(bytes, (byte) 0); + return Unpooled.wrappedBuffer(bytes); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java new file mode 100644 index 000000000..b5625bf8e --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/resume/ServerRSocketSessionTest.java @@ -0,0 +1,190 @@ +package io.rsocket.resume; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ReferenceCounted; +import io.rsocket.FrameAssert; +import io.rsocket.frame.FrameType; +import io.rsocket.frame.KeepAliveFrameCodec; +import io.rsocket.frame.ResumeFrameCodec; +import io.rsocket.keepalive.KeepAliveSupport; +import io.rsocket.test.util.TestClientTransport; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; +import reactor.test.scheduler.VirtualTimeScheduler; + +public class ServerRSocketSessionTest { + + @Test + void sessionTimeoutSmokeTest() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + // deactivate connection + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // resubscribe so a new connection is generated + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(50)); + // timeout should not terminate current connection + assertThat(transport.testConnection().isDisposed()).isFalse(); + + // send RESUME frame + final ByteBuf resumeFrame = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame, transport.testConnection()); + resumeFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be terminated + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.RESUME_OK) + .matches(ReferenceCounted::release); + + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(15)); + + // disconnects for the second time + transport.testConnection().dispose(); + assertThat(transport.testConnection().isDisposed()).isTrue(); + // ensures timeout has been started + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + transport.connect().subscribe(); + + assertThat(transport.testConnection().isDisposed()).isFalse(); + // timeout should be still active since no RESUME_Ok frame has been received yet + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isFalse(); + + // advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(61)); + + final ByteBuf resumeFrame1 = + ResumeFrameCodec.encode(transport.alloc(), Unpooled.EMPTY_BUFFER, 0, 0); + session.resumeWith(resumeFrame1, transport.testConnection()); + resumeFrame1.release(); + + // should obtain new connection + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be still active since no RESUME_OK frame has been received yet + assertThat(session.s).isEqualTo(Operators.cancelledSubscription()); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectComplete().verify(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } + + @Test + void shouldTerminateConnectionOnIllegalStateInKeepAliveFrame() { + final VirtualTimeScheduler virtualTimeScheduler = VirtualTimeScheduler.getOrSet(); + try { + final TestClientTransport transport = new TestClientTransport(); + final InMemoryResumableFramesStore framesStore = + new InMemoryResumableFramesStore("test", Unpooled.EMPTY_BUFFER, 100); + + transport.connect().subscribe(); + + final ResumableDuplexConnection resumableDuplexConnection = + new ResumableDuplexConnection( + "test", Unpooled.EMPTY_BUFFER, transport.testConnection(), framesStore); + + resumableDuplexConnection.receive().subscribe(); + + final ServerRSocketSession session = + new ServerRSocketSession( + Unpooled.EMPTY_BUFFER, + resumableDuplexConnection, + transport.testConnection(), + framesStore, + Duration.ofMinutes(1), + true); + + final KeepAliveSupport.ClientKeepAliveSupport keepAliveSupport = + new KeepAliveSupport.ClientKeepAliveSupport(transport.alloc(), 1000000, 10000000); + keepAliveSupport.resumeState(session); + session.setKeepAliveSupport(keepAliveSupport); + + // connection is active. just advance time + virtualTimeScheduler.advanceTimeBy(Duration.ofSeconds(10)); + assertThat(session.s).isNull(); + assertThat(session.isDisposed()).isFalse(); + + final ByteBuf keepAliveFrame = + KeepAliveFrameCodec.encode(transport.alloc(), false, 1529, Unpooled.EMPTY_BUFFER); + keepAliveSupport.receive(keepAliveFrame); + keepAliveFrame.release(); + + assertThat(transport.testConnection().isDisposed()).isTrue(); + // timeout should be terminated + assertThat(session.s).isNotNull(); + assertThat(session.isDisposed()).isTrue(); + + FrameAssert.assertThat(transport.testConnection().pollFrame()) + .hasStreamIdZero() + .typeOf(FrameType.ERROR) + .matches(ReferenceCounted::release); + + resumableDuplexConnection.onClose().as(StepVerifier::create).expectError().verify(); + keepAliveSupport.dispose(); + transport.alloc().assertHasNoLeaks(); + } finally { + VirtualTimeScheduler.reset(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/ByteBufUtils.java b/rsocket-core/src/test/java/io/rsocket/test/util/ByteBufUtils.java new file mode 100644 index 000000000..9bed415ae --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/ByteBufUtils.java @@ -0,0 +1,32 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import java.util.concurrent.ThreadLocalRandom; + +public final class ByteBufUtils { + + private ByteBufUtils() {} + + public static ByteBuf getRandomByteBuf(int size) { + byte[] bytes = new byte[size]; + ThreadLocalRandom.current().nextBytes(bytes); + return Unpooled.wrappedBuffer(bytes); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java new file mode 100644 index 000000000..cdfcefdc8 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/LocalDuplexConnection.java @@ -0,0 +1,124 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; + +public class LocalDuplexConnection implements DuplexConnection { + private final ByteBufAllocator allocator; + private final Sinks.Many send; + private final Sinks.Many receive; + private final Sinks.Empty onClose; + private final String name; + + public LocalDuplexConnection( + String name, + ByteBufAllocator allocator, + Sinks.Many send, + Sinks.Many receive) { + this.name = name; + this.allocator = allocator; + this.send = send; + this.receive = receive; + this.onClose = Sinks.empty(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + System.out.println(name + " - " + frame.toString()); + send.tryEmitNext(frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + System.out.println(name + " - " + errorFrame.toString()); + send.tryEmitNext(errorFrame); + onClose.tryEmitEmpty(); + } + + @Override + public Flux receive() { + return receive + .asFlux() + .doOnNext(f -> System.out.println(name + " - " + f.toString())) + .transform( + Operators.lift( + (__, actual) -> + new CoreSubscriber() { + + @Override + public void onSubscribe(Subscription s) { + actual.onSubscribe(s); + } + + @Override + public void onNext(ByteBuf byteBuf) { + actual.onNext(byteBuf); + byteBuf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + })); + } + + @Override + public ByteBufAllocator alloc() { + return allocator; + } + + @Override + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress(name); + } + + @Override + public void dispose() { + onClose.tryEmitEmpty(); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); + } + + @Override + public Mono onClose() { + return onClose.asMono(); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java new file mode 100644 index 000000000..a33c4c4b3 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/MockRSocket.java @@ -0,0 +1,122 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class MockRSocket implements RSocket { + + private final AtomicInteger fnfCount; + private final AtomicInteger rrCount; + private final AtomicInteger rStreamCount; + private final AtomicInteger rSubCount; + private final AtomicInteger rChannelCount; + private final AtomicInteger pushCount; + private final RSocket delegate; + + public MockRSocket(RSocket delegate) { + this.delegate = delegate; + fnfCount = new AtomicInteger(); + rrCount = new AtomicInteger(); + rStreamCount = new AtomicInteger(); + rSubCount = new AtomicInteger(); + rChannelCount = new AtomicInteger(); + pushCount = new AtomicInteger(); + } + + @Override + public final Mono fireAndForget(Payload payload) { + return delegate.fireAndForget(payload).doOnSubscribe(s -> fnfCount.incrementAndGet()); + } + + @Override + public final Mono requestResponse(Payload payload) { + return delegate.requestResponse(payload).doOnSubscribe(s -> rrCount.incrementAndGet()); + } + + @Override + public final Flux requestStream(Payload payload) { + return delegate.requestStream(payload).doOnSubscribe(s -> rStreamCount.incrementAndGet()); + } + + @Override + public final Flux requestChannel(Publisher payloads) { + return delegate.requestChannel(payloads).doOnSubscribe(s -> rChannelCount.incrementAndGet()); + } + + @Override + public final Mono metadataPush(Payload payload) { + return delegate.metadataPush(payload).doOnSubscribe(s -> pushCount.incrementAndGet()); + } + + @Override + public double availability() { + return delegate.availability(); + } + + @Override + public void dispose() { + delegate.dispose(); + } + + @Override + public boolean isDisposed() { + return delegate.isDisposed(); + } + + @Override + public Mono onClose() { + return delegate.onClose(); + } + + public void assertFireAndForgetCount(int expected) { + assertCount(expected, "fire-and-forget", fnfCount); + } + + public void assertRequestResponseCount(int expected) { + assertCount(expected, "request-response", rrCount); + } + + public void assertRequestStreamCount(int expected) { + assertCount(expected, "request-stream", rStreamCount); + } + + public void assertRequestSubscriptionCount(int expected) { + assertCount(expected, "request-subscription", rSubCount); + } + + public void assertRequestChannelCount(int expected) { + assertCount(expected, "request-channel", rChannelCount); + } + + public void assertMetadataPushCount(int expected) { + assertCount(expected, "metadata-push", pushCount); + } + + private static void assertCount(int expected, String type, AtomicInteger counter) { + assertThat(counter.get()) + .describedAs("Unexpected invocations for " + type + '.') + .isEqualTo(expected); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/StringUtils.java b/rsocket-core/src/test/java/io/rsocket/test/util/StringUtils.java new file mode 100644 index 000000000..403eacb6d --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/StringUtils.java @@ -0,0 +1,34 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Collectors; + +public final class StringUtils { + + private static final String CANDIDATE_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + + private StringUtils() {} + + public static String getRandomString(int size) { + return ThreadLocalRandom.current() + .ints(size, 0, CANDIDATE_CHARS.length()) + .mapToObj(index -> ((Character) CANDIDATE_CHARS.charAt(index)).toString()) + .collect(Collectors.joining()); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java new file mode 100644 index 000000000..f02bc99a4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestClientTransport.java @@ -0,0 +1,43 @@ +package io.rsocket.test.util; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.transport.ClientTransport; +import java.time.Duration; +import reactor.core.publisher.Mono; + +public class TestClientTransport implements ClientTransport { + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofSeconds(1), "client"); + + private volatile TestDuplexConnection testDuplexConnection; + + int maxFrameLength = FRAME_LENGTH_MASK; + + @Override + public Mono connect() { + return Mono.fromSupplier(() -> testDuplexConnection = new TestDuplexConnection(allocator)); + } + + public TestDuplexConnection testConnection() { + return testDuplexConnection; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public TestClientTransport withMaxFrameLength(int maxFrameLength) { + this.maxFrameLength = maxFrameLength; + return this; + } + + @Override + public int maxFrameLength() { + return maxFrameLength; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java new file mode 100644 index 000000000..8793d6ca4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestDuplexConnection.java @@ -0,0 +1,194 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.frame.ErrorFrameCodec; +import java.net.SocketAddress; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.DirectProcessor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.FluxSink; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.util.annotation.NonNull; + +/** + * An implementation of {@link DuplexConnection} that provides functionality to modify the behavior + * dynamically. + */ +public class TestDuplexConnection implements DuplexConnection { + + private static final Logger logger = LoggerFactory.getLogger(TestDuplexConnection.class); + + private final LinkedBlockingQueue sent; + + private final DirectProcessor sentPublisher; + private final FluxSink sendSink; + private final DirectProcessor received; + private final FluxSink receivedSink; + private final MonoProcessor onClose; + private final LeaksTrackingByteBufAllocator allocator; + private volatile double availability = 1; + private volatile int initialSendRequestN = Integer.MAX_VALUE; + + public TestDuplexConnection(LeaksTrackingByteBufAllocator allocator) { + this.allocator = allocator; + this.sent = new LinkedBlockingQueue<>(); + this.received = DirectProcessor.create(); + this.receivedSink = received.sink(); + this.sentPublisher = DirectProcessor.create(); + this.sendSink = sentPublisher.sink(); + this.onClose = MonoProcessor.create(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (availability <= 0) { + throw new IllegalStateException("RSocket not available. Availability: " + availability); + } + + sendSink.next(frame); + sent.offer(frame); + } + + @Override + public Flux receive() { + return received.transform( + Operators.lift( + (__, actual) -> + new CoreSubscriber() { + @Override + public void onSubscribe(Subscription s) { + actual.onSubscribe(s); + } + + @Override + public void onNext(ByteBuf byteBuf) { + actual.onNext(byteBuf); + byteBuf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + })); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + sendSink.next(errorFrame); + sent.offer(errorFrame); + + final Throwable cause = e.getCause(); + if (cause == null) { + onClose.onComplete(); + } else { + onClose.onError(cause); + } + } + + @Override + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + @Override + public SocketAddress remoteAddress() { + return new TestLocalSocketAddress("TestDuplexConnection"); + } + + @Override + public double availability() { + return availability; + } + + @Override + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + + public boolean isEmpty() { + return sent.isEmpty(); + } + + @NonNull + public ByteBuf awaitFrame() { + try { + return sent.take(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + public ByteBuf pollFrame() { + return sent.poll(); + } + + public void setAvailability(double availability) { + this.availability = availability; + } + + public BlockingQueue getSent() { + return sent; + } + + public Publisher getSentAsPublisher() { + return sentPublisher; + } + + public void addToReceivedBuffer(ByteBuf... received) { + for (ByteBuf frame : received) { + this.receivedSink.next(frame); + } + } + + public void clearSendReceiveBuffers() { + sent.clear(); + } + + public void setInitialSendRequestN(int initialSendRequestN) { + this.initialSendRequestN = initialSendRequestN; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java new file mode 100644 index 000000000..2dad2cc1f --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestLocalSocketAddress.java @@ -0,0 +1,46 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test.util; + +import java.net.SocketAddress; +import java.util.Objects; + +public final class TestLocalSocketAddress extends SocketAddress { + + private static final long serialVersionUID = 2608695156052100164L; + + private final String name; + + /** + * Creates a new instance. + * + * @param name the name representing the address + * @throws NullPointerException if {@code name} is {@code null} + */ + public TestLocalSocketAddress(String name) { + this.name = Objects.requireNonNull(name, "name must not be null"); + } + + /** Return the name for this connection. */ + public String getName() { + return name; + } + + @Override + public String toString() { + return "[local address] " + name; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java new file mode 100644 index 000000000..fa9331d3b --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestServerTransport.java @@ -0,0 +1,90 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test.util; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.Closeable; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import io.rsocket.transport.ServerTransport; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +public class TestServerTransport implements ServerTransport { + private final Sinks.One connSink = Sinks.one(); + private TestDuplexConnection connection; + private final LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + + int maxFrameLength = FRAME_LENGTH_MASK; + + @Override + public Mono start(ConnectionAcceptor acceptor) { + connSink + .asMono() + .flatMap(duplexConnection -> acceptor.apply(duplexConnection)) + .subscribe(ignored -> {}, err -> disposeConnection(), this::disposeConnection); + return Mono.just( + new Closeable() { + @Override + public Mono onClose() { + return connSink.asMono().then(); + } + + @Override + public void dispose() { + connSink.tryEmitEmpty(); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return connSink.scan(Scannable.Attr.TERMINATED) + || connSink.scan(Scannable.Attr.CANCELLED); + } + }); + } + + private void disposeConnection() { + TestDuplexConnection c = connection; + if (c != null) { + c.dispose(); + } + } + + public TestDuplexConnection connect() { + TestDuplexConnection c = new TestDuplexConnection(allocator); + connection = c; + connSink.tryEmitValue(c); + return c; + } + + public LeaksTrackingByteBufAllocator alloc() { + return allocator; + } + + public TestServerTransport withMaxFrameLength(int maxFrameLength) { + this.maxFrameLength = maxFrameLength; + return this; + } + + @Override + public int maxFrameLength() { + return maxFrameLength; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/test/util/TestSubscriber.java b/rsocket-core/src/test/java/io/rsocket/test/util/TestSubscriber.java new file mode 100644 index 000000000..e88b39648 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/test/util/TestSubscriber.java @@ -0,0 +1,67 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test.util; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; + +import io.rsocket.Payload; +import org.mockito.Mockito; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +public class TestSubscriber { + public static Subscriber create() { + return create(Long.MAX_VALUE); + } + + public static Subscriber create(long initialRequest) { + @SuppressWarnings("unchecked") + Subscriber mock = mock(Subscriber.class); + + Mockito.doAnswer( + invocation -> { + if (initialRequest > 0) { + ((Subscription) invocation.getArguments()[0]).request(initialRequest); + } + return null; + }) + .when(mock) + .onSubscribe(any(Subscription.class)); + + return mock; + } + + public static Payload anyPayload() { + return any(Payload.class); + } + + public static Subscriber createCancelling() { + @SuppressWarnings("unchecked") + Subscriber mock = mock(Subscriber.class); + + Mockito.doAnswer( + invocation -> { + ((Subscription) invocation.getArguments()[0]).cancel(); + return null; + }) + .when(mock) + .onSubscribe(any(Subscription.class)); + + return mock; + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java new file mode 100644 index 000000000..2ad944d09 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/ByteBufPayloadTest.java @@ -0,0 +1,64 @@ +package io.rsocket.util; + +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; +import io.rsocket.Payload; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class ByteBufPayloadTest { + + @Test + public void shouldIndicateThatItHasMetadata() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = ByteBufPayload.create("data"); + + Assertions.assertThat(payload.hasMetadata()).isFalse(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + ByteBufPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + Assertions.assertThat(payload.hasMetadata()).isTrue(); + Assertions.assertThat(payload.release()).isTrue(); + } + + @Test + public void shouldThrowExceptionIfAccessAfterRelease() { + Payload payload = ByteBufPayload.create("data", "metadata"); + + Assertions.assertThat(payload.release()).isTrue(); + + Assertions.assertThatThrownBy(payload::hasMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::data).isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::metadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::sliceMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::touch) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(() -> payload.touch("test")) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getData) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadata) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getDataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + Assertions.assertThatThrownBy(payload::getMetadataUtf8) + .isInstanceOf(IllegalReferenceCountException.class); + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java new file mode 100644 index 000000000..f04de78b6 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/DefaultPayloadTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.buffer.LeaksTrackingByteBufAllocator; +import java.nio.ByteBuffer; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.Test; + +public class DefaultPayloadTest { + public static final String DATA_VAL = "data"; + public static final String METADATA_VAL = "metadata"; + + @Test + public void testReuse() { + Payload p = DefaultPayload.create(DATA_VAL, METADATA_VAL); + assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); + assertDataAndMetadata(p, DATA_VAL, METADATA_VAL); + } + + public void assertDataAndMetadata(Payload p, String dataVal, String metadataVal) { + assertThat(p.getDataUtf8()).describedAs("Unexpected data.").isEqualTo(dataVal); + if (metadataVal == null) { + assertThat(p.hasMetadata()).describedAs("Non-null metadata").isEqualTo(false); + } else { + assertThat(p.hasMetadata()).describedAs("Null metadata").isEqualTo(true); + assertThat(p.getMetadataUtf8()).describedAs("Unexpected metadata.").isEqualTo(metadataVal); + } + } + + @Test + public void staticMethods() { + assertDataAndMetadata(DefaultPayload.create(DATA_VAL, METADATA_VAL), DATA_VAL, METADATA_VAL); + assertDataAndMetadata(DefaultPayload.create(DATA_VAL), DATA_VAL, null); + } + + @Test + public void shouldIndicateThatItHasNotMetadata() { + Payload payload = DefaultPayload.create("data"); + + assertThat(payload.hasMetadata()).isFalse(); + } + + @Test + public void shouldIndicateThatItHasMetadata1() { + Payload payload = + DefaultPayload.create(Unpooled.wrappedBuffer("data".getBytes()), Unpooled.EMPTY_BUFFER); + + assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldIndicateThatItHasMetadata2() { + Payload payload = + DefaultPayload.create(ByteBuffer.wrap("data".getBytes()), ByteBuffer.allocate(0)); + + assertThat(payload.hasMetadata()).isTrue(); + } + + @Test + public void shouldReleaseGivenByteBufDataAndMetadataUpOnPayloadCreation() { + LeaksTrackingByteBufAllocator allocator = + LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT); + for (byte i = 0; i < 126; i++) { + ByteBuf data = allocator.buffer(); + data.writeByte(i); + + boolean metadataPresent = ThreadLocalRandom.current().nextBoolean(); + ByteBuf metadata = null; + if (metadataPresent) { + metadata = allocator.buffer(); + metadata.writeByte(i + 1); + } + + Payload payload = DefaultPayload.create(data, metadata); + + assertThat(payload.getData()).isEqualTo(ByteBuffer.wrap(new byte[] {i})); + + assertThat(payload.getMetadata()) + .isEqualTo( + metadataPresent + ? ByteBuffer.wrap(new byte[] {(byte) (i + 1)}) + : DefaultPayload.EMPTY_BUFFER); + allocator.assertHasNoLeaks(); + } + } +} diff --git a/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java new file mode 100644 index 000000000..46e0f77f4 --- /dev/null +++ b/rsocket-core/src/test/java/io/rsocket/util/NumberUtilsTest.java @@ -0,0 +1,187 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.util; + +import static org.assertj.core.api.Assertions.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class NumberUtilsTest { + + @DisplayName("returns int value with postitive int") + @Test + void requireNonNegativeInt() { + assertThat(NumberUtils.requireNonNegative(Integer.MAX_VALUE, "test-message")) + .isEqualTo(Integer.MAX_VALUE); + } + + @DisplayName( + "requireNonNegative with int argument throws IllegalArgumentException with negative value") + @Test + void requireNonNegativeIntNegative() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireNonNegative(Integer.MIN_VALUE, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requireNonNegative with int argument throws NullPointerException with null message") + @Test + void requireNonNegativeIntNullMessage() { + assertThatNullPointerException() + .isThrownBy(() -> NumberUtils.requireNonNegative(Integer.MIN_VALUE, null)) + .withMessage("message must not be null"); + } + + @DisplayName("requireNonNegative returns int value with zero") + @Test + void requireNonNegativeIntZero() { + assertThat(NumberUtils.requireNonNegative(0, "test-message")).isEqualTo(0); + } + + @DisplayName("requirePositive returns int value with positive int") + @Test + void requirePositiveInt() { + assertThat(NumberUtils.requirePositive(Integer.MAX_VALUE, "test-message")) + .isEqualTo(Integer.MAX_VALUE); + } + + @DisplayName( + "requirePositive with int argument throws IllegalArgumentException with negative value") + @Test + void requirePositiveIntNegative() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(Integer.MIN_VALUE, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requirePositive with int argument throws NullPointerException with null message") + @Test + void requirePositiveIntNullMessage() { + assertThatNullPointerException() + .isThrownBy(() -> NumberUtils.requirePositive(Integer.MIN_VALUE, null)) + .withMessage("message must not be null"); + } + + @DisplayName("requirePositive with int argument throws IllegalArgumentException with zero value") + @Test + void requirePositiveIntZero() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(0, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requirePositive returns long value with positive long") + @Test + void requirePositiveLong() { + assertThat(NumberUtils.requirePositive(Long.MAX_VALUE, "test-message")) + .isEqualTo(Long.MAX_VALUE); + } + + @DisplayName( + "requirePositive with long argument throws IllegalArgumentException with negative value") + @Test + void requirePositiveLongNegative() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(Long.MIN_VALUE, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requirePositive with long argument throws NullPointerException with null message") + @Test + void requirePositiveLongNullMessage() { + assertThatNullPointerException() + .isThrownBy(() -> NumberUtils.requirePositive(Long.MIN_VALUE, null)) + .withMessage("message must not be null"); + } + + @DisplayName("requirePositive with long argument throws IllegalArgumentException with zero value") + @Test + void requirePositiveLongZero() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requirePositive(0L, "test-message")) + .withMessage("test-message"); + } + + @DisplayName("requireUnsignedByte returns length if 255") + @Test + void requireUnsignedByte() { + assertThat(NumberUtils.requireUnsignedByte((1 << 8) - 1)).isEqualTo(255); + } + + @DisplayName("requireUnsignedByte throws IllegalArgumentException if larger than 255") + @Test + void requireUnsignedByteOverFlow() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireUnsignedByte(1 << 8)) + .withMessage("%d is larger than 8 bits", 1 << 8); + } + + @DisplayName("requireUnsignedMedium returns length if 16_777_215") + @Test + void requireUnsignedMedium() { + assertThat(NumberUtils.requireUnsignedMedium((1 << 24) - 1)).isEqualTo(16_777_215); + } + + @DisplayName("requireUnsignedMedium throws IllegalArgumentException if larger than 16_777_215") + @Test + void requireUnsignedMediumOverFlow() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireUnsignedMedium(1 << 24)) + .withMessage("%d is larger than 24 bits", 1 << 24); + } + + @DisplayName("requireUnsignedShort returns length if 65_535") + @Test + void requireUnsignedShort() { + assertThat(NumberUtils.requireUnsignedShort((1 << 16) - 1)).isEqualTo(65_535); + } + + @DisplayName("requireUnsignedShort throws IllegalArgumentException if larger than 65_535") + @Test + void requireUnsignedShortOverFlow() { + assertThatIllegalArgumentException() + .isThrownBy(() -> NumberUtils.requireUnsignedShort(1 << 16)) + .withMessage("%d is larger than 16 bits", 1 << 16); + } + + @Test + void encodeUnsignedMedium() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 129); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(129); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(129); + } + + @Test + void encodeUnsignedMediumLarge() { + ByteBuf buffer = ByteBufAllocator.DEFAULT.buffer(); + NumberUtils.encodeUnsignedMedium(buffer, 0xFFFFFC); + buffer.markReaderIndex(); + + assertThat(buffer.readUnsignedMedium()).as("reading as unsigned medium").isEqualTo(16777212); + + buffer.resetReaderIndex(); + assertThat(buffer.readMedium()).as("reading as signed medium").isEqualTo(-4); + } +} diff --git a/rsocket-core/src/test/resources/META-INF/services/org.assertj.core.presentation.Representation b/rsocket-core/src/test/resources/META-INF/services/org.assertj.core.presentation.Representation new file mode 100644 index 000000000..9ac418a0c --- /dev/null +++ b/rsocket-core/src/test/resources/META-INF/services/org.assertj.core.presentation.Representation @@ -0,0 +1,16 @@ +# +# Copyright 2015-2018 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +io.rsocket.frame.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension new file mode 100644 index 000000000..2b51ba0de --- /dev/null +++ b/rsocket-core/src/test/resources/META-INF/services/org.junit.jupiter.api.extension.Extension @@ -0,0 +1 @@ +io.rsocket.frame.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-core/src/test/resources/logback-test.xml b/rsocket-core/src/test/resources/logback-test.xml new file mode 100644 index 000000000..9081698fb --- /dev/null +++ b/rsocket-core/src/test/resources/logback-test.xml @@ -0,0 +1,32 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + + + + + + diff --git a/rsocket-examples/build.gradle b/rsocket-examples/build.gradle new file mode 100644 index 000000000..4059eb957 --- /dev/null +++ b/rsocket-examples/build.gradle @@ -0,0 +1,50 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java' +} + +dependencies { + implementation project(':rsocket-core') + implementation project(':rsocket-load-balancer') + implementation project(':rsocket-transport-local') + implementation project(':rsocket-transport-netty') + + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") + + implementation 'com.netflix.concurrency-limits:concurrency-limits-core' + implementation "io.micrometer:micrometer-core" + implementation "io.micrometer:micrometer-tracing" + implementation project(":rsocket-micrometer") + + runtimeOnly 'ch.qos.logback:logback-classic' + + testImplementation project(':rsocket-test') + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.mockito:mockito-core' + testImplementation 'org.assertj:assertj-core' + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.awaitility:awaitility' + testImplementation "io.micrometer:micrometer-test" + testImplementation "io.micrometer:micrometer-tracing-integration-test" + + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' +} + +description = 'Example usage of the RSocket library' diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java new file mode 100644 index 000000000..463043020 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/channel/ChannelEchoClient.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.channel; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public final class ChannelEchoClient { + + private static final Logger logger = LoggerFactory.getLogger(ChannelEchoClient.class); + + public static void main(String[] args) { + + SocketAcceptor echoAcceptor = + SocketAcceptor.forRequestChannel( + payloads -> + Flux.from(payloads) + .map(Payload::getDataUtf8) + .map(s -> "Echo: " + s) + .map(DefaultPayload::create)); + + RSocketServer.create(echoAcceptor).bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); + + socket + .requestChannel( + Flux.interval(Duration.ofMillis(1000)).map(i -> DefaultPayload.create("Hello"))) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .doFinally(signalType -> socket.dispose()) + .then() + .block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java new file mode 100644 index 000000000..dfbbcde53 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/client/RSocketClientExample.java @@ -0,0 +1,55 @@ +package io.rsocket.examples.transport.tcp.client; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; +import reactor.util.retry.Retry; + +public class RSocketClientExample { + static final Logger logger = LoggerFactory.getLogger(RSocketClientExample.class); + + public static void main(String[] args) { + + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + String data = p.getDataUtf8(); + logger.info("Received request data {}", data); + + Payload responsePayload = DefaultPayload.create("Echo: " + data); + p.release(); + + return Mono.just(responsePayload); + })) + .bind(TcpServerTransport.create("localhost", 7000)) + .delaySubscription(Duration.ofSeconds(5)) + .doOnNext(cc -> logger.info("Server started on the address : {}", cc.address())) + .block(); + + Mono source = + RSocketConnector.create() + .reconnect(Retry.backoff(50, Duration.ofMillis(500))) + .connect(TcpClientTransport.create("localhost", 7000)); + + RSocketClient.from(source) + .requestResponse(Mono.just(DefaultPayload.create("Test Request"))) + .doOnSubscribe(s -> logger.info("Executing Request")) + .doOnNext( + d -> { + logger.info("Received response data {}", d.getDataUtf8()); + d.release(); + }) + .repeat(10) + .blockLast(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java new file mode 100644 index 000000000..89b22749f --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/fnf/TaskProcessingWithServerSideNotificationsExample.java @@ -0,0 +1,237 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.examples.transport.tcp.fnf; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadLocalRandom; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.BaseSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.concurrent.Queues; + +/** + * An example of long-running tasks processing (a.k.a Kafka style) where a client submits tasks over + * request `FireAndForget` and then receives results over the same method but on it is own side. + * + *

This example shows a case when the client may disappear, however, another a client can connect + * again and receive undelivered completed tasks remaining for the previous one. + */ +public class TaskProcessingWithServerSideNotificationsExample { + + public static void main(String[] args) throws InterruptedException { + Sinks.Many tasksProcessor = + Sinks.many().unicast().onBackpressureBuffer(Queues.unboundedMultiproducer().get()); + ConcurrentMap> idToCompletedTasksMap = new ConcurrentHashMap<>(); + ConcurrentMap idToRSocketMap = new ConcurrentHashMap<>(); + BackgroundWorker backgroundWorker = + new BackgroundWorker(tasksProcessor.asFlux(), idToCompletedTasksMap, idToRSocketMap); + + RSocketServer.create(new TasksAcceptor(tasksProcessor, idToCompletedTasksMap, idToRSocketMap)) + .bindNow(TcpServerTransport.create(9991)); + + Logger logger = LoggerFactory.getLogger("RSocket.Client.ID[Test]"); + + Mono rSocketMono = + RSocketConnector.create() + .setupPayload(DefaultPayload.create("Test")) + .acceptor( + SocketAcceptor.forFireAndForget( + p -> { + logger.info("Received Processed Task[{}]", p.getDataUtf8()); + p.release(); + return Mono.empty(); + })) + .connect(TcpClientTransport.create(9991)); + + RSocket rSocketRequester1 = rSocketMono.block(); + + for (int i = 0; i < 10; i++) { + rSocketRequester1.fireAndForget(DefaultPayload.create("task" + i)).block(); + } + + Thread.sleep(4000); + + rSocketRequester1.dispose(); + logger.info("Disposed"); + + Thread.sleep(4000); + + RSocket rSocketRequester2 = rSocketMono.block(); + + logger.info("Reconnected"); + + Thread.sleep(10000); + } + + static class BackgroundWorker extends BaseSubscriber { + final ConcurrentMap> idToCompletedTasksMap; + final ConcurrentMap idToRSocketMap; + + BackgroundWorker( + Flux taskProducer, + ConcurrentMap> idToCompletedTasksMap, + ConcurrentMap idToRSocketMap) { + + this.idToCompletedTasksMap = idToCompletedTasksMap; + this.idToRSocketMap = idToRSocketMap; + + // mimic a long running task processing + taskProducer + .concatMap( + t -> + Mono.delay(Duration.ofMillis(ThreadLocalRandom.current().nextInt(200, 2000))) + .thenReturn(t)) + .subscribe(this); + } + + @Override + protected void hookOnNext(Task task) { + BlockingQueue completedTasksQueue = + idToCompletedTasksMap.computeIfAbsent(task.id, __ -> new LinkedBlockingQueue<>()); + + completedTasksQueue.offer(task); + RSocket rSocket = idToRSocketMap.get(task.id); + if (rSocket != null) { + rSocket + .fireAndForget(DefaultPayload.create(task.content)) + .subscribe(null, e -> {}, () -> completedTasksQueue.remove(task)); + } + } + } + + static class TasksAcceptor implements SocketAcceptor { + + static final Logger logger = LoggerFactory.getLogger(TasksAcceptor.class); + + final Sinks.Many tasksToProcess; + final ConcurrentMap> idToCompletedTasksMap; + final ConcurrentMap idToRSocketMap; + + TasksAcceptor( + Sinks.Many tasksToProcess, + ConcurrentMap> idToCompletedTasksMap, + ConcurrentMap idToRSocketMap) { + this.tasksToProcess = tasksToProcess; + this.idToCompletedTasksMap = idToCompletedTasksMap; + this.idToRSocketMap = idToRSocketMap; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + String id = setup.getDataUtf8(); + logger.info("Accepting a new client connection with ID {}", id); + // sendingRSocket represents here an RSocket requester to a remote peer + + if (this.idToRSocketMap.compute( + id, (__, old) -> old == null || old.isDisposed() ? sendingSocket : old) + == sendingSocket) { + return Mono.just( + new RSocketTaskHandler(idToRSocketMap, tasksToProcess, id, sendingSocket)) + .doOnSuccess(__ -> checkTasksToDeliver(sendingSocket, id)); + } + + return Mono.error( + new IllegalStateException("There is already a client connected with the same ID")); + } + + private void checkTasksToDeliver(RSocket sendingSocket, String id) { + logger.info("Accepted a new client connection with ID {}. Checking for remaining tasks", id); + BlockingQueue tasksToDeliver = this.idToCompletedTasksMap.get(id); + + if (tasksToDeliver == null || tasksToDeliver.isEmpty()) { + // means nothing yet to send + return; + } + + logger.info("Found remaining tasks to deliver for client {}", id); + + for (; ; ) { + Task task = tasksToDeliver.poll(); + + if (task == null) { + return; + } + + sendingSocket + .fireAndForget(DefaultPayload.create(task.content)) + .subscribe( + null, + e -> { + // offers back a task if it has not been delivered + tasksToDeliver.offer(task); + }); + } + } + + private static class RSocketTaskHandler implements RSocket { + + private final String id; + private final RSocket sendingSocket; + private ConcurrentMap idToRSocketMap; + private Sinks.Many tasksToProcess; + + public RSocketTaskHandler( + ConcurrentMap idToRSocketMap, + Sinks.Many tasksToProcess, + String id, + RSocket sendingSocket) { + this.id = id; + this.sendingSocket = sendingSocket; + this.idToRSocketMap = idToRSocketMap; + this.tasksToProcess = tasksToProcess; + } + + @Override + public Mono fireAndForget(Payload payload) { + logger.info("Received a Task[{}] from Client.ID[{}]", payload.getDataUtf8(), id); + Sinks.EmitResult result = tasksToProcess.tryEmitNext(new Task(id, payload.getDataUtf8())); + payload.release(); + return result.isFailure() ? Mono.error(new Sinks.EmissionException(result)) : Mono.empty(); + } + + @Override + public void dispose() { + idToRSocketMap.remove(id, sendingSocket); + } + } + } + + static class Task { + final String id; + final String content; + + Task(String id, String content) { + this.id = id; + this.content = content; + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java new file mode 100644 index 000000000..272caf7a0 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LeaseManager.java @@ -0,0 +1,144 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import java.util.concurrent.BlockingDeque; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class LeaseManager implements Runnable { + + static final Logger logger = LoggerFactory.getLogger(LeaseManager.class); + + volatile int activeConnectionsCount; + static final AtomicIntegerFieldUpdater ACTIVE_CONNECTIONS_COUNT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "activeConnectionsCount"); + + volatile int stateAndInFlight; + static final AtomicIntegerFieldUpdater STATE_AND_IN_FLIGHT = + AtomicIntegerFieldUpdater.newUpdater(LeaseManager.class, "stateAndInFlight"); + + static final int MASK_PAUSED = 0b1_000_0000_0000_0000_0000_0000_0000_0000; + static final int MASK_IN_FLIGHT = 0b0_111_1111_1111_1111_1111_1111_1111_1111; + + final BlockingDeque sendersQueue = new LinkedBlockingDeque<>(); + final Scheduler worker = Schedulers.newSingle(LeaseManager.class.getName()); + + final int capacity; + final int ttl; + + public LeaseManager(int capacity, int ttl) { + this.capacity = capacity; + this.ttl = ttl; + } + + @Override + public void run() { + try { + LimitBasedLeaseSender leaseSender = sendersQueue.poll(); + + if (leaseSender == null) { + return; + } + + if (leaseSender.isDisposed()) { + logger.debug("Connection[" + leaseSender.connectionId + "]: LeaseSender is Disposed"); + worker.schedule(this); + return; + } + + int limit = leaseSender.limitAlgorithm.getLimit(); + + if (limit == 0) { + throw new IllegalStateException("Limit is 0"); + } + + if (pauseIfNoCapacity()) { + sendersQueue.addFirst(leaseSender); + logger.debug("Pause execution. Not enough capacity"); + return; + } + + leaseSender.sendLease(ttl, limit); + sendersQueue.offer(leaseSender); + + int activeConnections = activeConnectionsCount; + int nextDelay = activeConnections == 0 ? ttl : (ttl / activeConnections); + + logger.debug("Next check happens in " + nextDelay + "ms"); + + worker.schedule(this, nextDelay, TimeUnit.MILLISECONDS); + } catch (Throwable e) { + logger.error("LeaseSender failed to send lease", e); + } + } + + int incrementInFlightAndGet() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight + 1; + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return nextInFlight; + } + } + } + + void decrementInFlight() { + for (; ; ) { + int state = stateAndInFlight; + int paused = state & MASK_PAUSED; + int inFlight = stateAndInFlight & MASK_IN_FLIGHT; + + // assume overflow is impossible due to max concurrency in RSocket it self + int nextInFlight = inFlight - 1; + + if (inFlight == capacity && paused == MASK_PAUSED) { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight)) { + logger.debug("Resume execution"); + worker.schedule(this); + return; + } + } else { + if (STATE_AND_IN_FLIGHT.compareAndSet(this, state, nextInFlight | paused)) { + return; + } + } + } + } + + boolean pauseIfNoCapacity() { + int capacity = this.capacity; + for (; ; ) { + int inFlight = stateAndInFlight; + + if (inFlight < capacity) { + return false; + } + + if (STATE_AND_IN_FLIGHT.compareAndSet(this, inFlight, inFlight | MASK_PAUSED)) { + return true; + } + } + } + + void unregister() { + ACTIVE_CONNECTIONS_COUNT.decrementAndGet(this); + } + + void register(LimitBasedLeaseSender sender) { + sendersQueue.offer(sender); + final int activeCount = ACTIVE_CONNECTIONS_COUNT.getAndIncrement(this); + + if (activeCount == 0) { + worker.schedule(this); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java new file mode 100644 index 000000000..8e1b27823 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedLeaseSender.java @@ -0,0 +1,54 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.rsocket.lease.Lease; +import io.rsocket.lease.TrackingLeaseSender; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; +import reactor.util.concurrent.Queues; + +public class LimitBasedLeaseSender extends LimitBasedStatsCollector implements TrackingLeaseSender { + + static final Logger logger = LoggerFactory.getLogger(LimitBasedLeaseSender.class); + + final String connectionId; + final Sinks.Many sink = + Sinks.many().unicast().onBackpressureBuffer(Queues.one().get()); + + public LimitBasedLeaseSender( + String connectionId, LeaseManager leaseManager, Limit limitAlgorithm) { + super(leaseManager, limitAlgorithm); + this.connectionId = connectionId; + } + + @Override + public Flux send() { + logger.info("Received new leased Connection[" + connectionId + "]"); + + leaseManager.register(this); + + return sink.asFlux(); + } + + public void sendLease(int ttl, int amount) { + final Lease nextLease = Lease.create(Duration.ofMillis(ttl), amount); + final Sinks.EmitResult result = sink.tryEmitNext(nextLease); + + if (result.isFailure()) { + logger.warn( + "Connection[" + + connectionId + + "]. Issued Lease: [" + + nextLease + + "] was not sent due to " + + result); + } else { + if (logger.isDebugEnabled()) { + logger.debug("To Connection[" + connectionId + "]: Issued Lease: [" + nextLease + "]"); + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java new file mode 100644 index 000000000..7f639ab87 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/common/LimitBasedStatsCollector.java @@ -0,0 +1,73 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.common; + +import com.netflix.concurrency.limits.Limit; +import io.netty.buffer.ByteBuf; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.RequestInterceptor; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.LongSupplier; +import reactor.util.annotation.Nullable; + +public class LimitBasedStatsCollector extends AtomicBoolean implements RequestInterceptor { + + final LeaseManager leaseManager; + final Limit limitAlgorithm; + + final ConcurrentMap inFlightMap = new ConcurrentHashMap<>(); + final ConcurrentMap timeMap = new ConcurrentHashMap<>(); + + final LongSupplier clock = System::nanoTime; + + public LimitBasedStatsCollector(LeaseManager leaseManager, Limit limitAlgorithm) { + this.leaseManager = leaseManager; + this.limitAlgorithm = limitAlgorithm; + } + + @Override + public void onStart(int streamId, FrameType requestType, @Nullable ByteBuf metadata) { + long startTime = clock.getAsLong(); + + int currentInFlight = leaseManager.incrementInFlightAndGet(); + + inFlightMap.put(streamId, currentInFlight); + timeMap.put(streamId, startTime); + } + + @Override + public void onReject( + Throwable rejectionReason, FrameType requestType, @Nullable ByteBuf metadata) {} + + @Override + public void onTerminate(int streamId, FrameType requestType, @Nullable Throwable t) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, t != null); + } + + @Override + public void onCancel(int streamId, FrameType requestType) { + leaseManager.decrementInFlight(); + + Long startTime = timeMap.remove(streamId); + Integer currentInflight = inFlightMap.remove(streamId); + + limitAlgorithm.onSample(startTime, clock.getAsLong() - startTime, currentInflight, true); + } + + @Override + public boolean isDisposed() { + return get(); + } + + @Override + public void dispose() { + if (!getAndSet(true)) { + leaseManager.unregister(); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java new file mode 100644 index 000000000..a18dd9484 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/Task.java @@ -0,0 +1,27 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +// emulating a worker that process data from the queue +public class Task implements Runnable { + private static final Logger logger = LoggerFactory.getLogger(Task.class); + + final String message; + final int processingTime; + + Task(String message, int processingTime) { + this.message = message; + this.processingTime = processingTime; + } + + @Override + public void run() { + logger.info("Processing Task[{}]", message); + try { + Thread.sleep(processingTime); // emulating processing + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java new file mode 100644 index 000000000..cbecadfc3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/controller/TasksHandlingRSocket.java @@ -0,0 +1,44 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.controller; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Scheduler; + +public class TasksHandlingRSocket implements RSocket { + + private static final Logger logger = LoggerFactory.getLogger(TasksHandlingRSocket.class); + + final Disposable terminatable; + final Scheduler workScheduler; + final int processingTime; + + public TasksHandlingRSocket(Disposable terminatable, Scheduler scheduler, int processingTime) { + this.terminatable = terminatable; + this.workScheduler = scheduler; + this.processingTime = processingTime; + } + + @Override + public Mono fireAndForget(Payload payload) { + + // specifically to show that lease can limit rate of fnf requests in + // that example + String message = payload.getDataUtf8(); + payload.release(); + + return Mono.fromRunnable(new Task(message, processingTime)) + // schedule task on specific, limited in size scheduler + .subscribeOn(workScheduler) + // if errors - terminates server + .doOnError( + t -> { + logger.error("Queue has been overflowed. Terminating server"); + terminatable.dispose(); + System.exit(9); + }); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java new file mode 100644 index 000000000..30eb0c0e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RequestingServer.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Comparator; +import java.util.concurrent.PriorityBlockingQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class RequestingServer { + + private static final Logger logger = LoggerFactory.getLogger(RequestingServer.class); + + public static void main(String[] args) { + PriorityBlockingQueue rSockets = + new PriorityBlockingQueue<>( + 16, Comparator.comparingDouble(RSocket::availability).reversed()); + + CloseableChannel server = + RSocketServer.create( + (setup, sendingSocket) -> { + logger.info("Received new connection"); + return Mono.just(new RSocket() {}) + .doAfterTerminate(() -> rSockets.put(sendingSocket)); + }) + .lease(spec -> spec.maxPendingRequests(Integer.MAX_VALUE)) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + logger.info("Server started on port {}", server.address().getPort()); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .flatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + + return Mono.fromCallable( + () -> { + RSocket rSocket = rSockets.take(); + rSockets.offer(rSocket); + return rSocket; + }) + .flatMap( + clientRSocket -> + clientRSocket.fireAndForget(ByteBufPayload.create("" + tick))) + .retry(); + }) + .blockLast(); + + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java new file mode 100644 index 000000000..4a06855b2 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/invertmulticlient/RespondingClient.java @@ -0,0 +1,67 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.invertmulticlient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.client.TcpClientTransport; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingClient { + private static final Logger logger = LoggerFactory.getLogger(RespondingClient.class); + + public static final int PROCESSING_TASK_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager periodicLeaseSender = + new LeaseManager(CONCURRENT_WORKERS_COUNT, PROCESSING_TASK_TIME); + + Disposable.Composite disposable = Disposables.composite(); + RSocket clientRSocket = + RSocketConnector.create() + .acceptor( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, PROCESSING_TASK_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + periodicLeaseSender, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + disposable.add(clientRSocket); + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/README.MD new file mode 100644 index 000000000..e69de29bb diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java new file mode 100644 index 000000000..c2fde38e3 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RequestingClient.java @@ -0,0 +1,41 @@ +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class RequestingClient { + private static final Logger logger = LoggerFactory.getLogger(RequestingClient.class); + + public static void main(String[] args) { + + RSocket clientRSocket = + RSocketConnector.create() + .lease() + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + Objects.requireNonNull(clientRSocket); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + .concatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); + }) + .blockLast(); + + clientRSocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java new file mode 100644 index 000000000..b54330450 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/advanced/multiclient/RespondingServer.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.advanced.multiclient; + +import com.netflix.concurrency.limits.limit.VegasLimit; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketServer; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LeaseManager; +import io.rsocket.examples.transport.tcp.lease.advanced.common.LimitBasedLeaseSender; +import io.rsocket.examples.transport.tcp.lease.advanced.controller.TasksHandlingRSocket; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.util.UUID; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; + +public class RespondingServer { + + private static final Logger logger = LoggerFactory.getLogger(RespondingServer.class); + + public static final int TASK_PROCESSING_TIME = 500; + public static final int CONCURRENT_WORKERS_COUNT = 1; + public static final int QUEUE_CAPACITY = 50; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + BlockingQueue tasksQueue = new ArrayBlockingQueue<>(QUEUE_CAPACITY); + + ThreadPoolExecutor threadPoolExecutor = + new ThreadPoolExecutor(1, CONCURRENT_WORKERS_COUNT, 1, TimeUnit.MINUTES, tasksQueue); + + Scheduler workScheduler = Schedulers.fromExecutorService(threadPoolExecutor); + + LeaseManager leaseManager = new LeaseManager(CONCURRENT_WORKERS_COUNT, TASK_PROCESSING_TIME); + + Disposable.Composite disposable = Disposables.composite(); + CloseableChannel server = + RSocketServer.create( + SocketAcceptor.with( + new TasksHandlingRSocket(disposable, workScheduler, TASK_PROCESSING_TIME))) + .lease( + (config) -> + config.sender( + new LimitBasedLeaseSender( + UUID.randomUUID().toString(), + leaseManager, + VegasLimit.newBuilder() + .initialLimit(CONCURRENT_WORKERS_COUNT) + .maxConcurrency(QUEUE_CAPACITY) + .build()))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + disposable.add(server); + + logger.info("Server started on port {}", server.address().getPort()); + server.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java new file mode 100644 index 000000000..c54335ccc --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/lease/simple/LeaseExample.java @@ -0,0 +1,160 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.lease.simple; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.lease.Lease; +import io.rsocket.lease.LeaseSender; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class LeaseExample { + + private static final Logger logger = LoggerFactory.getLogger(LeaseExample.class); + + private static final String SERVER_TAG = "server"; + private static final String CLIENT_TAG = "client"; + + public static void main(String[] args) { + // Queue for incoming messages represented as Flux + // Imagine that every fireAndForget that is pushed is processed by a worker + + int queueCapacity = 50; + BlockingQueue messagesQueue = new ArrayBlockingQueue<>(queueCapacity); + + // emulating a worker that process data from the queue + Thread workerThread = + new Thread( + () -> { + try { + while (!Thread.currentThread().isInterrupted()) { + String message = messagesQueue.take(); + logger.info("Process message {}", message); + Thread.sleep(500); // emulating processing + } + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + + workerThread.start(); + + CloseableChannel server = + RSocketServer.create( + (setup, sendingSocket) -> + Mono.just( + new RSocket() { + @Override + public Mono fireAndForget(Payload payload) { + // add element. if overflows errors and terminates execution + // specifically to show that lease can limit rate of fnf requests in + // that example + try { + if (!messagesQueue.offer(payload.getDataUtf8())) { + logger.error("Queue has been overflowed. Terminating execution"); + sendingSocket.dispose(); + workerThread.interrupt(); + } + } finally { + payload.release(); + } + return Mono.empty(); + } + })) + .lease(leases -> leases.sender(new LeaseCalculator(SERVER_TAG, messagesQueue))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket clientRSocket = + RSocketConnector.create() + .lease((config) -> config.maxPendingRequests(1)) + .connect(TcpClientTransport.create(server.address())) + .block(); + + Objects.requireNonNull(clientRSocket); + + // generate stream of fnfs + Flux.generate( + () -> 0L, + (state, sink) -> { + sink.next(state); + return state + 1; + }) + // here we wait for the first lease for the responder side and start execution + // on if there is allowance + .concatMap( + tick -> { + logger.info("Requesting FireAndForget({})", tick); + return clientRSocket.fireAndForget(ByteBufPayload.create("" + tick)); + }) + .blockLast(); + + clientRSocket.onClose().block(); + server.dispose(); + } + + /** + * This is a class responsible for making decision on whether Responder is ready to receive new + * FireAndForget or not base in the number of messages enqueued.
+ * In the nutshell this is responder-side rate-limiter logic which is created for every new + * connection.
+ * In real-world projects this class has to issue leases based on real metrics + */ + private static class LeaseCalculator implements LeaseSender { + final String tag; + final BlockingQueue queue; + + public LeaseCalculator(String tag, BlockingQueue queue) { + this.tag = tag; + this.queue = queue; + } + + @Override + public Flux send() { + Duration ttlDuration = Duration.ofSeconds(5); + // The interval function is used only for the demo purpose and should not be + // considered as the way to issue leases. + // For advanced RateLimiting with Leasing + // consider adopting https://github.com/Netflix/concurrency-limits#server-limiter + return Flux.interval(Duration.ZERO, ttlDuration.dividedBy(2)) + .handle( + (__, sink) -> { + // put queue.remainingCapacity() + 1 here if you want to observe that app is + // terminated because of the queue overflowing + int requests = queue.remainingCapacity(); + + // reissue new lease only if queue has remaining capacity to + // accept more requests + if (requests > 0) { + sink.next(Lease.create(ttlDuration, requests)); + } + }); + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java new file mode 100644 index 000000000..abed4a52d --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/loadbalancer/RoundRobinRSocketLoadbalancerExample.java @@ -0,0 +1,110 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.examples.transport.tcp.loadbalancer; + +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketServer; +import io.rsocket.loadbalance.LoadbalanceRSocketClient; +import io.rsocket.loadbalance.LoadbalanceTarget; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class RoundRobinRSocketLoadbalancerExample { + + public static void main(String[] args) { + CloseableChannel server1 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 1 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 1 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8080)); + + CloseableChannel server2 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 2 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 2 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8081)); + + CloseableChannel server3 = + RSocketServer.create( + SocketAcceptor.forRequestResponse( + p -> { + System.out.println("Server 3 got fnf " + p.getDataUtf8()); + return Mono.just(DefaultPayload.create("Server 3 response")) + .delayElement(Duration.ofMillis(100)); + })) + .bindNow(TcpServerTransport.create(8082)); + + LoadbalanceTarget target8080 = LoadbalanceTarget.from("8080", TcpClientTransport.create(8080)); + LoadbalanceTarget target8081 = LoadbalanceTarget.from("8081", TcpClientTransport.create(8081)); + LoadbalanceTarget target8082 = LoadbalanceTarget.from("8082", TcpClientTransport.create(8082)); + + Flux> producer = + Flux.interval(Duration.ofSeconds(5)) + .log() + .map( + i -> { + int val = i.intValue(); + switch (val) { + case 0: + return Collections.emptyList(); + case 1: + return Collections.singletonList(target8080); + case 2: + return Arrays.asList(target8080, target8081); + case 3: + return Arrays.asList(target8080, target8082); + case 4: + return Arrays.asList(target8081, target8082); + case 5: + return Arrays.asList(target8080, target8081, target8082); + case 6: + return Collections.emptyList(); + case 7: + return Collections.emptyList(); + default: + return Arrays.asList(target8080, target8081, target8082); + } + }); + + RSocketClient rSocketClient = + LoadbalanceRSocketClient.builder(producer).roundRobinLoadbalanceStrategy().build(); + + for (int i = 0; i < 10000; i++) { + try { + rSocketClient.requestResponse(Mono.just(DefaultPayload.create("test" + i))).block(); + } catch (Throwable t) { + // no ops + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java new file mode 100644 index 000000000..a0a02a946 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/CompositeMetadataExample.java @@ -0,0 +1,102 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import java.util.Objects; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class CompositeMetadataExample { + static final Logger logger = LoggerFactory.getLogger(CompositeMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that every metadata payload will be encoded using + // CompositeMetadata layout as specified in the following subspec + // https://github.com/rsocket/rsocket/blob/master/Extensions/CompositeMetadata.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_COMPOSITE_METADATA.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + final CompositeByteBuf compositeMetadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + + CompositeMetadataCodec.encodeAndAddMetadata( + compositeMetadata, + ByteBufAllocator.DEFAULT, + WellKnownMimeType.MESSAGE_RSOCKET_ROUTING, + routeMetadata); + + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), compositeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + + for (CompositeMetadata.Entry metadatum : compositeMetadata) { + if (Objects.requireNonNull(metadatum.getMimeType()) + .equals(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString())) { + return new RoutingMetadata(metadatum.getContent()).iterator().next(); + } + } + + return null; + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java new file mode 100644 index 000000000..2aee18bf9 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/metadata/routing/RoutingMetadataExample.java @@ -0,0 +1,83 @@ +/* + * Copyright 2015-Present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.metadata.routing; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TaggingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.util.Collections; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public class RoutingMetadataExample { + static final Logger logger = LoggerFactory.getLogger(RoutingMetadataExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.forRequestResponse( + payload -> { + final String route = decodeRoute(payload.sliceMetadata()); + + logger.info("Received RequestResponse[route={}]", route); + + payload.release(); + + if ("my.test.route".equals(route)) { + return Mono.just(ByteBufPayload.create("Hello From My Test Route")); + } + + return Mono.error(new IllegalArgumentException("Route " + route + " not found")); + })) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + // here we specify that route will be encoded using + // Routing&Tagging Metadata layout specified at this + // subspec https://github.com/rsocket/rsocket/blob/master/Extensions/Routing.md + .metadataMimeType(WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final ByteBuf routeMetadata = + TaggingMetadataCodec.createTaggingContent( + ByteBufAllocator.DEFAULT, Collections.singletonList("my.test.route")); + socket + .requestResponse( + ByteBufPayload.create( + ByteBufUtil.writeUtf8(ByteBufAllocator.DEFAULT, "HelloWorld"), routeMetadata)) + .log() + .block(); + } + + static String decodeRoute(ByteBuf metadata) { + final RoutingMetadata routingMetadata = new RoutingMetadata(metadata); + + return routingMetadata.iterator().next(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java new file mode 100644 index 000000000..5491a1aab --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/plugins/LimitRateInterceptorExample.java @@ -0,0 +1,83 @@ +package io.rsocket.examples.transport.tcp.plugins; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.plugins.LimitRateInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public class LimitRateInterceptorExample { + + private static final Logger logger = LoggerFactory.getLogger(LimitRateInterceptorExample.class); + + public static void main(String[] args) { + RSocketServer.create( + SocketAcceptor.with( + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.interval(Duration.ofMillis(100)) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .doOnRequest( + e -> logger.debug("Server publisher receives request for " + e)); + } + })) + .interceptors(registry -> registry.forResponder(LimitRateInterceptor.forResponder(64))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(LimitRateInterceptor.forRequester(64))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + logger.debug( + "\n\nStart of requestStream interaction\n" + "----------------------------------\n"); + + socket + .requestStream(DefaultPayload.create("Hello")) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .block(); + + logger.debug( + "\n\nStart of requestChannel interaction\n" + "-----------------------------------\n"); + + socket + .requestChannel( + Flux.generate( + () -> 1L, + (s, sink) -> { + sink.next(DefaultPayload.create("Next " + s)); + return ++s; + }) + .doOnRequest(e -> logger.debug("Client publisher receives request for " + e))) + .doOnRequest(e -> logger.debug("Client sends requestN(" + e + ")")) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .doFinally(signalType -> socket.dispose()) + .then() + .block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java new file mode 100644 index 000000000..0c372d2d8 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/requestresponse/HelloWorldClient.java @@ -0,0 +1,69 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.requestresponse; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Mono; + +public final class HelloWorldClient { + + private static final Logger logger = LoggerFactory.getLogger(HelloWorldClient.class); + + public static void main(String[] args) { + + RSocket rsocket = + new RSocket() { + boolean fail = true; + + @Override + public Mono requestResponse(Payload p) { + if (fail) { + fail = false; + return Mono.error(new Throwable("Simulated error")); + } else { + return Mono.just(p); + } + } + }; + + RSocketServer.create(SocketAcceptor.with(rsocket)) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.connectWith(TcpClientTransport.create("localhost", 7000)).block(); + + for (int i = 0; i < 3; i++) { + socket + .requestResponse(DefaultPayload.create("Hello")) + .map(Payload::getDataUtf8) + .onErrorReturn("error") + .doOnNext(logger::debug) + .block(); + } + + socket.dispose(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java new file mode 100644 index 000000000..6724ca93f --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/Files.java @@ -0,0 +1,141 @@ +package io.rsocket.examples.transport.tcp.resume; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import java.io.BufferedInputStream; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.SynchronousSink; + +class Files { + private static final Logger logger = LoggerFactory.getLogger(Files.class); + + public static Flux fileSource(String fileName, int chunkSizeBytes) { + return Flux.generate( + () -> new FileState(fileName, chunkSizeBytes), FileState::consumeNext, FileState::dispose); + } + + public static Subscriber fileSink(String fileName, int windowSize) { + return new Subscriber() { + Subscription s; + int requests = windowSize; + OutputStream outputStream; + int receivedBytes; + int receivedCount; + + @Override + public void onSubscribe(Subscription s) { + this.s = s; + this.s.request(requests); + } + + @Override + public void onNext(Payload payload) { + ByteBuf data = payload.data(); + receivedBytes += data.readableBytes(); + receivedCount += 1; + logger.debug("Received file chunk: " + receivedCount + ". Total size: " + receivedBytes); + if (outputStream == null) { + outputStream = open(fileName); + } + write(outputStream, data); + payload.release(); + + requests--; + if (requests == windowSize / 2) { + requests += windowSize; + s.request(windowSize); + } + } + + private void write(OutputStream outputStream, ByteBuf byteBuf) { + try { + byteBuf.readBytes(outputStream, byteBuf.readableBytes()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public void onError(Throwable t) { + close(outputStream); + } + + @Override + public void onComplete() { + close(outputStream); + } + + private OutputStream open(String filename) { + try { + /*do not buffer for demo purposes*/ + return new FileOutputStream(filename); + } catch (FileNotFoundException e) { + throw new RuntimeException(e); + } + } + + private void close(OutputStream stream) { + if (stream != null) { + try { + stream.close(); + } catch (IOException e) { + } + } + } + }; + } + + private static class FileState { + private final String fileName; + private final int chunkSizeBytes; + private BufferedInputStream inputStream; + private byte[] chunkBytes; + + public FileState(String fileName, int chunkSizeBytes) { + this.fileName = fileName; + this.chunkSizeBytes = chunkSizeBytes; + } + + public FileState consumeNext(SynchronousSink sink) { + if (inputStream == null) { + InputStream in = getClass().getClassLoader().getResourceAsStream(fileName); + if (in == null) { + sink.error(new FileNotFoundException(fileName)); + return this; + } + this.inputStream = new BufferedInputStream(in); + this.chunkBytes = new byte[chunkSizeBytes]; + } + try { + int consumedBytes = inputStream.read(chunkBytes); + if (consumedBytes == -1) { + sink.complete(); + } else { + sink.next(Unpooled.copiedBuffer(chunkBytes, 0, consumedBytes)); + } + } catch (IOException e) { + sink.error(e); + } + return this; + } + + public void dispose() { + if (inputStream != null) { + try { + inputStream.close(); + } catch (IOException e) { + } + } + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java new file mode 100644 index 000000000..ba82c7c93 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/ResumeFileTransfer.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.resume; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.util.retry.Retry; + +public class ResumeFileTransfer { + + /*amount of file chunks requested by subscriber: n, refilled on n/2 of received items*/ + private static final int PREFETCH_WINDOW_SIZE = 4; + private static final Logger logger = LoggerFactory.getLogger(ResumeFileTransfer.class); + + public static void main(String[] args) { + + Resume resume = + new Resume() + .sessionDuration(Duration.ofMinutes(5)) + .retry( + Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)) + .doBeforeRetry(s -> logger.debug("Disconnected. Trying to resume..."))); + + RequestCodec codec = new RequestCodec(); + + CloseableChannel server = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> { + Request request = codec.decode(payload); + payload.release(); + String fileName = request.getFileName(); + int chunkSize = request.getChunkSize(); + + Flux ticks = Flux.interval(Duration.ofMillis(500)).onBackpressureDrop(); + + return Files.fileSource(fileName, chunkSize) + .map(DefaultPayload::create) + .zipWith(ticks, (p, tick) -> p) + .log("server"); + })) + .resume(resume) + .bindNow(TcpServerTransport.create("localhost", 8000)); + + RSocket client = + RSocketConnector.create() + .resume(resume) + .connect(TcpClientTransport.create("localhost", 8001)) + .block(); + + client + .requestStream(codec.encode(new Request(16, "lorem.txt"))) + .log("client") + .doFinally(s -> server.dispose()) + .subscribe(Files.fileSink("rsocket-examples/build/lorem_output.txt", PREFETCH_WINDOW_SIZE)); + + server.onClose().block(); + } + + private static class RequestCodec { + + public Payload encode(Request request) { + String encoded = request.getChunkSize() + ":" + request.getFileName(); + return DefaultPayload.create(encoded); + } + + public Request decode(Payload payload) { + String encoded = payload.getDataUtf8(); + String[] chunkSizeAndFileName = encoded.split(":"); + int chunkSize = Integer.parseInt(chunkSizeAndFileName[0]); + String fileName = chunkSizeAndFileName[1]; + return new Request(chunkSize, fileName); + } + } + + private static class Request { + private final int chunkSize; + private final String fileName; + + public Request(int chunkSize, String fileName) { + this.chunkSize = chunkSize; + this.fileName = fileName; + } + + public int getChunkSize() { + return chunkSize; + } + + public String getFileName() { + return fileName; + } + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md new file mode 100644 index 000000000..55e761fe8 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/resume/readme.md @@ -0,0 +1,29 @@ +1. Start socat. It is used for emulation of transport disconnects + +`socat -d TCP-LISTEN:8001,fork,reuseaddr TCP:localhost:8000` + +2. start `ResumeFileTransfer.main` + +3. terminate/start socat periodically for session resumption + +`ResumeFileTransfer` output is as follows + +``` +Received file chunk: 7. Total size: 112 +Received file chunk: 8. Total size: 128 +Received file chunk: 9. Total size: 144 +Received file chunk: 10. Total size: 160 +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Disconnected. Trying to resume connection... +Received file chunk: 11. Total size: 176 +Received file chunk: 12. Total size: 192 +Received file chunk: 13. Total size: 208 +Received file chunk: 14. Total size: 224 +Received file chunk: 15. Total size: 240 +Received file chunk: 16. Total size: 256 +``` + +It transfers file from `resources/lorem.txt` to `build/out/lorem_output.txt` in chunks of 16 bytes every 500 millis diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java new file mode 100644 index 000000000..af0df3be1 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ClientStreamingToServer.java @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.stream; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + +public final class ClientStreamingToServer { + + private static final Logger logger = LoggerFactory.getLogger(ClientStreamingToServer.class); + + public static void main(String[] args) throws InterruptedException { + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.interval(Duration.ofMillis(100)) + .map(aLong -> DefaultPayload.create("Interval: " + aLong)))) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket socket = + RSocketConnector.create() + .setupPayload(DefaultPayload.create("test", "test")) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + final Payload payload = DefaultPayload.create("Hello"); + socket + .requestStream(payload) + .map(Payload::getDataUtf8) + .doOnNext(logger::debug) + .take(10) + .then() + .doFinally(signalType -> socket.dispose()) + .then() + .block(); + + Thread.sleep(1000000); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java new file mode 100644 index 000000000..10ed34553 --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/tcp/stream/ServerStreamingToClient.java @@ -0,0 +1,60 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.tcp.stream; + +import static io.rsocket.SocketAcceptor.forRequestStream; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public final class ServerStreamingToClient { + + public static void main(String[] args) { + + RSocketServer.create( + (setup, rsocket) -> { + rsocket + .requestStream(DefaultPayload.create("Hello-Bidi")) + .map(Payload::getDataUtf8) + .log() + .subscribe(); + + return Mono.just(new RSocket() {}); + }) + .bindNow(TcpServerTransport.create("localhost", 7000)); + + RSocket rsocket = + RSocketConnector.create() + .acceptor( + forRequestStream( + payload -> + Flux.interval(Duration.ofSeconds(1)) + .map(aLong -> DefaultPayload.create("Bi-di Response => " + aLong)))) + .connect(TcpClientTransport.create("localhost", 7000)) + .block(); + + rsocket.onClose().block(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java new file mode 100644 index 000000000..89304853c --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketAggregationSample.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-present the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketAggregationSample { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketAggregationSample.class); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor connectionAcceptor = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .port(0) + .handle( + (req, res) -> + res.sendWebsocket( + (in, out) -> + connectionAcceptor + .apply( + new WebsocketDuplexConnection( + (Connection) in.aggregateFrames())) + .then(out.neverComplete()))) + .bindNow(); + + WebsocketClientTransport transport = + WebsocketClientTransport.create(server.host(), server.port()); + + RSocket clientRSocket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(transport) + .block(); + + Flux.range(1, 100) + .concatMap(i -> clientRSocket.requestResponse(ByteBufPayload.create("Hello " + i))) + .doOnNext(payload -> logger.debug("Processed " + payload.getDataUtf8())) + .blockLast(); + clientRSocket.dispose(); + server.dispose(); + } +} diff --git a/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java new file mode 100644 index 000000000..72e003d2a --- /dev/null +++ b/rsocket-examples/src/main/java/io/rsocket/examples/transport/ws/WebSocketHeadersSample.java @@ -0,0 +1,99 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.examples.transport.ws; + +import io.netty.handler.codec.http.HttpResponseStatus; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +public class WebSocketHeadersSample { + + private static final Logger logger = LoggerFactory.getLogger(WebSocketHeadersSample.class); + + public static void main(String[] args) { + + ServerTransport.ConnectionAcceptor connectionAcceptor = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .port(0) + .route( + routes -> + routes.get( + "/", + (req, res) -> { + if (req.requestHeaders().containsValue("Authorization", "test", true)) { + return res.sendWebsocket( + (in, out) -> + connectionAcceptor + .apply(new WebsocketDuplexConnection((Connection) in)) + .then(out.neverComplete())); + } + res.status(HttpResponseStatus.UNAUTHORIZED); + return res.send(); + })) + .bindNow(); + + logger.debug( + "\n\nStart of Authorized WebSocket Connection\n----------------------------------\n"); + + WebsocketClientTransport transport = + WebsocketClientTransport.create(server.host(), server.port()) + .header("Authorization", "test"); + + RSocket clientRSocket = + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(transport) + .block(); + + Flux.range(1, 100) + .concatMap(i -> clientRSocket.requestResponse(ByteBufPayload.create("Hello " + i))) + .doOnNext(payload -> logger.debug("Processed " + payload.getDataUtf8())) + .blockLast(); + clientRSocket.dispose(); + + logger.debug( + "\n\nStart of Unauthorized WebSocket Upgrade\n----------------------------------\n"); + + RSocketConnector.create() + .keepAlive(Duration.ofMinutes(10), Duration.ofMinutes(10)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(WebsocketClientTransport.create(server.host(), server.port())) + .block(); + } +} diff --git a/rsocket-examples/src/main/resources/logback.xml b/rsocket-examples/src/main/resources/logback.xml new file mode 100644 index 000000000..780a70c99 --- /dev/null +++ b/rsocket-examples/src/main/resources/logback.xml @@ -0,0 +1,34 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + + diff --git a/rsocket-examples/src/main/resources/lorem.txt b/rsocket-examples/src/main/resources/lorem.txt new file mode 100644 index 000000000..e035ea86d --- /dev/null +++ b/rsocket-examples/src/main/resources/lorem.txt @@ -0,0 +1,32 @@ +Alteration literature to or an sympathize mr imprudence. Of is ferrars subject as enjoyed or tedious cottage. +Procuring as in resembled by in agreeable. Next long no gave mr eyes. Admiration advantages no he celebrated so pianoforte unreserved. +Not its herself forming charmed amiable. Him why feebly expect future now. + +Situation admitting promotion at or to perceived be. Mr acuteness we as estimable enjoyment up. +An held late as felt know. Learn do allow solid to grave. Middleton suspicion age her attention. +Chiefly several bed its wishing. Is so moments on chamber pressed to. Doubtful yet way properly answered humanity its desirous. + Minuter believe service arrived civilly add all. Acuteness allowance an at eagerness favourite in extensive exquisite ye. + + Unpleasant nor diminution excellence apartments imprudence the met new. Draw part them he an to he roof only. + Music leave say doors him. Tore bred form if sigh case as do. Staying he no looking if do opinion. + Sentiments way understood end partiality and his. + + Ladyship it daughter securing procured or am moreover mr. Put sir she exercise vicinity cheerful wondered. + Continual say suspicion provision you neglected sir curiosity unwilling. Simplicity end themselves increasing led day sympathize yet. + General windows effects not are drawing man garrets. Common indeed garden you his ladies out yet. Preference imprudence contrasted to remarkably in on. + Taken now you him trees tears any. Her object giving end sister except oppose. + + No comfort do written conduct at prevent manners on. Celebrated contrasted discretion him sympathize her collecting occasional. + Do answered bachelor occasion in of offended no concerns. Supply worthy warmth branch of no ye. Voice tried known to as my to. + Though wished merits or be. Alone visit use these smart rooms ham. No waiting in on enjoyed placing it inquiry. + + So insisted received is occasion advanced honoured. Among ready to which up. Attacks smiling and may out assured moments man nothing outward. + Thrown any behind afford either the set depend one temper. Instrument melancholy in acceptance collecting frequently be if. + Zealously now pronounce existence add you instantly say offending. Merry their far had widen was. Concerns no in expenses raillery formerly. + + As am hastily invited settled at limited civilly fortune me. Really spring in extent an by. Judge but built gay party world. + Of so am he remember although required. Bachelor unpacked be advanced at. Confined in declared marianne is vicinity. + + In alteration insipidity impression by travelling reasonable up motionless. Of regard warmth by unable sudden garden ladies. + No kept hung am size spot no. Likewise led and dissuade rejoiced welcomed husbands boy. Do listening on he suspected resembled. + Water would still if to. Position boy required law moderate was may. \ No newline at end of file diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java new file mode 100644 index 000000000..ac311a231 --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/IntegrationTest.java @@ -0,0 +1,208 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.plugins.SocketAcceptorInterceptor; +import io.rsocket.test.TestSubscriber; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.RSocketProxy; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class IntegrationTest { + + private static final RSocketInterceptor requesterInterceptor; + private static final RSocketInterceptor responderInterceptor; + private static final SocketAcceptorInterceptor clientAcceptorInterceptor; + private static final SocketAcceptorInterceptor serverAcceptorInterceptor; + private static final DuplexConnectionInterceptor connectionInterceptor; + + private static volatile boolean calledRequester = false; + private static volatile boolean calledResponder = false; + private static volatile boolean calledClientAcceptor = false; + private static volatile boolean calledServerAcceptor = false; + private static volatile boolean calledFrame = false; + + static { + requesterInterceptor = + reactiveSocket -> + new RSocketProxy(reactiveSocket) { + @Override + public Mono requestResponse(Payload payload) { + calledRequester = true; + return reactiveSocket.requestResponse(payload); + } + }; + + responderInterceptor = + reactiveSocket -> + new RSocketProxy(reactiveSocket) { + @Override + public Mono requestResponse(Payload payload) { + calledResponder = true; + return reactiveSocket.requestResponse(payload); + } + }; + + clientAcceptorInterceptor = + acceptor -> + (setup, sendingSocket) -> { + calledClientAcceptor = true; + return acceptor.accept(setup, sendingSocket); + }; + + serverAcceptorInterceptor = + acceptor -> + (setup, sendingSocket) -> { + calledServerAcceptor = true; + return acceptor.accept(setup, sendingSocket); + }; + + connectionInterceptor = + (type, connection) -> { + calledFrame = true; + return connection; + }; + } + + private CloseableChannel server; + private RSocket client; + private AtomicInteger requestCount; + private CountDownLatch disconnectionCounter; + private AtomicInteger errorCount; + + @BeforeEach + public void startup() { + errorCount = new AtomicInteger(); + requestCount = new AtomicInteger(); + disconnectionCounter = new CountDownLatch(1); + + server = + RSocketServer.create( + (setup, sendingSocket) -> { + sendingSocket + .onClose() + .doFinally(signalType -> disconnectionCounter.countDown()) + .subscribe(); + + return Mono.just( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")) + .doOnSubscribe(s -> requestCount.incrementAndGet()); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.range(1, 10_000) + .map(i -> DefaultPayload.create("data -> " + i)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads); + } + }); + }) + .interceptors( + registry -> + registry + .forResponder(responderInterceptor) + .forSocketAcceptor(serverAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + + client = + RSocketConnector.create() + .interceptors( + registry -> + registry + .forRequester(requesterInterceptor) + .forSocketAcceptor(clientAcceptorInterceptor) + .forConnection(connectionInterceptor)) + .connect(TcpClientTransport.create(server.address())) + .block(); + } + + @AfterEach + public void teardown() { + server.dispose(); + } + + @Test + @Timeout(5_000L) + public void testRequest() { + client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); + assertThat(requestCount).as("Server did not see the request.").hasValue(1); + + assertThat(calledRequester).isTrue(); + assertThat(calledResponder).isTrue(); + assertThat(calledClientAcceptor).isTrue(); + assertThat(calledServerAcceptor).isTrue(); + assertThat(calledFrame).isTrue(); + } + + @Test + @Timeout(5_000L) + public void testStream() { + Subscriber subscriber = TestSubscriber.createCancelling(); + client.requestStream(DefaultPayload.create("start")).subscribe(subscriber); + + verify(subscriber).onSubscribe(any()); + verifyNoMoreInteractions(subscriber); + } + + @Test + @Timeout(5_000L) + public void testClose() throws InterruptedException { + client.dispose(); + disconnectionCounter.await(); + } + + @Test // (timeout = 5_000L) + public void testCallRequestWithErrorAndThenRequest() { + assertThatThrownBy(client.requestChannel(Mono.error(new Throwable("test")))::blockLast) + .hasMessage("java.lang.Throwable: test"); + + testRequest(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java new file mode 100644 index 000000000..48e5baaa7 --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/InteractionsLoadTest.java @@ -0,0 +1,104 @@ +package io.rsocket.integration; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.test.SlowTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.function.Supplier; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; + +public class InteractionsLoadTest { + + @Test + @SlowTest + public void channel() { + CloseableChannel server = + RSocketServer.create(SocketAcceptor.with(new EchoRSocket())) + .bind(TcpServerTransport.create("localhost", 0)) + .block(Duration.ofSeconds(10)); + + RSocket clientRSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .block(Duration.ofSeconds(10)); + + int concurrency = 16; + Flux.range(1, concurrency) + .flatMap( + v -> + clientRSocket + .requestChannel( + input().onBackpressureDrop().map(iv -> DefaultPayload.create("foo"))) + .limitRate(10000), + concurrency) + .timeout(Duration.ofSeconds(5)) + .doOnNext( + p -> { + String data = p.getDataUtf8(); + if (!data.equals("bar")) { + throw new IllegalStateException("Channel Client Bad message: " + data); + } + }) + .window(Duration.ofSeconds(1)) + .flatMap(Flux::count) + .doOnNext(d -> System.out.println("Got: " + d)) + .take(Duration.ofMinutes(1)) + .doOnTerminate(server::dispose) + .subscribe(); + + server.onClose().block(); + } + + private static Flux input() { + Flux interval = Flux.interval(Duration.ofMillis(1)).onBackpressureDrop(); + for (int i = 0; i < 10; i++) { + interval = interval.mergeWith(interval); + } + return interval; + } + + private static class EchoRSocket implements RSocket { + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .map( + p -> { + String data = p.getDataUtf8(); + if (!data.equals("foo")) { + throw new IllegalStateException("Channel Server Bad message: " + data); + } + return DefaultPayload.create("bar"); + }); + } + + @Override + public Flux requestStream(Payload payload) { + return Flux.just(payload) + .map( + p -> { + String data = p.getDataUtf8(); + return data; + }) + .doOnNext( + (data) -> { + if (!data.equals("foo")) { + throw new IllegalStateException("Stream Server Bad message: " + data); + } + }) + .flatMap( + data -> { + Supplier p = () -> DefaultPayload.create("bar"); + return Flux.range(1, 100).map(v -> p.get()); + }); + } + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java new file mode 100644 index 000000000..1924668fb --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TcpIntegrationTest.java @@ -0,0 +1,195 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import io.rsocket.util.RSocketProxy; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Schedulers; + +public class TcpIntegrationTest { + private RSocket handler; + + private CloseableChannel server; + + @BeforeEach + public void startup() { + server = + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + } + + private RSocket buildClient() { + return RSocketConnector.connectWith(TcpClientTransport.create(server.address())).block(); + } + + @AfterEach + public void cleanup() { + server.dispose(); + } + + @Test + @Timeout(15_000L) + public void testCompleteWithoutNext() { + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.empty(); + } + }; + RSocket client = buildClient(); + Boolean hasElements = + client.requestStream(DefaultPayload.create("REQUEST", "META")).log().hasElements().block(); + + assertThat(hasElements).isFalse(); + } + + @Test + @Timeout(15_000L) + public void testSingleStream() { + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.just(DefaultPayload.create("RESPONSE", "METADATA")); + } + }; + + RSocket client = buildClient(); + + Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockLast(); + + assertThat(result.getDataUtf8()).isEqualTo("RESPONSE"); + } + + @Test + @Timeout(15_000L) + public void testZeroPayload() { + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return Flux.just(EmptyPayload.INSTANCE); + } + }; + + RSocket client = buildClient(); + + Payload result = client.requestStream(DefaultPayload.create("REQUEST", "META")).blockFirst(); + + assertThat(result.getDataUtf8()).isEmpty(); + } + + @Test + @Timeout(15_000L) + public void testRequestResponseErrors() { + handler = + new RSocket() { + boolean first = true; + + @Override + public Mono requestResponse(Payload payload) { + if (first) { + first = false; + return Mono.error(new RuntimeException("EX")); + } else { + return Mono.just(DefaultPayload.create("SUCCESS")); + } + } + }; + + RSocket client = buildClient(); + + Payload response1 = + client + .requestResponse(DefaultPayload.create("REQUEST", "META")) + .onErrorReturn(DefaultPayload.create("ERROR")) + .block(); + Payload response2 = + client + .requestResponse(DefaultPayload.create("REQUEST", "META")) + .onErrorReturn(DefaultPayload.create("ERROR")) + .block(); + + assertThat(response1.getDataUtf8()).isEqualTo("ERROR"); + assertThat(response2.getDataUtf8()).isEqualTo("SUCCESS"); + } + + @Test + @Timeout(15_000L) + public void testTwoConcurrentStreams() throws InterruptedException { + ConcurrentHashMap> map = new ConcurrentHashMap<>(); + Sinks.Many processor1 = Sinks.many().unicast().onBackpressureBuffer(); + map.put("REQUEST1", processor1); + Sinks.Many processor2 = Sinks.many().unicast().onBackpressureBuffer(); + map.put("REQUEST2", processor2); + + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + return map.get(payload.getDataUtf8()).asFlux(); + } + }; + + RSocket client = buildClient(); + + Flux response1 = client.requestStream(DefaultPayload.create("REQUEST1")); + Flux response2 = client.requestStream(DefaultPayload.create("REQUEST2")); + + CountDownLatch nextCountdown = new CountDownLatch(2); + CountDownLatch completeCountdown = new CountDownLatch(2); + + response1 + .subscribeOn(Schedulers.newSingle("1")) + .subscribe(c -> nextCountdown.countDown(), t -> {}, completeCountdown::countDown); + + response2 + .subscribeOn(Schedulers.newSingle("2")) + .subscribe(c -> nextCountdown.countDown(), t -> {}, completeCountdown::countDown); + + processor1.tryEmitNext(DefaultPayload.create("RESPONSE1A")); + processor2.tryEmitNext(DefaultPayload.create("RESPONSE2A")); + + nextCountdown.await(); + + processor1.tryEmitComplete(); + processor2.tryEmitComplete(); + + completeCountdown.await(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java new file mode 100644 index 000000000..cd96584ed --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/TestingStreaming.java @@ -0,0 +1,124 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration; + +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.exceptions.ApplicationErrorException; +import io.rsocket.transport.local.LocalClientTransport; +import io.rsocket.transport.local.LocalServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; + +public class TestingStreaming { + LocalServerTransport serverTransport = LocalServerTransport.create("test"); + + @Test + public void testRangeButThrowException() { + Closeable server = null; + try { + server = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .doOnNext( + i -> { + if (i > 3) { + throw new RuntimeException("BOOM!"); + } + }) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) + .block(); + + Assertions.assertThatThrownBy( + Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i))::blockLast) + .isInstanceOf(ApplicationErrorException.class); + + } finally { + server.dispose(); + } + } + + @Test + public void testRangeOfConsumers() { + Closeable server = null; + try { + server = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 1000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) + .block(); + + Flux.range(1, 6).flatMap(i -> consumer("connection number -> " + i)).blockLast(); + } finally { + server.dispose(); + } + } + + private Flux consumer(String s) { + return RSocketConnector.connectWith(LocalClientTransport.create("test")) + .flatMapMany( + rSocket -> { + AtomicInteger count = new AtomicInteger(); + return Flux.range(1, 100) + .flatMap( + i -> rSocket.requestStream(DefaultPayload.create("i -> " + i)).take(100), 1); + }); + } + + @Test + public void testSingleConsumer() { + Closeable server = null; + try { + server = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(1, 10_000) + .map(l -> DefaultPayload.create("l -> " + l)) + .cast(Payload.class))) + .bind(serverTransport) + .block(); + + consumer("1").blockLast(); + + } finally { + server.dispose(); + } + } + + @Test + public void testFluxOnly() { + Flux longFlux = Flux.interval(Duration.ofMillis(1)).onBackpressureDrop(); + + Flux.range(1, 60).flatMap(i -> longFlux.take(1000)).blockLast(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java new file mode 100644 index 000000000..870ecf0cd --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/integration/observation/ObservationIntegrationTest.java @@ -0,0 +1,246 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration.observation; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.observation.DefaultMeterObservationHandler; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.micrometer.core.tck.MeterRegistryAssert; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationHandler; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.tracing.test.SampleTestRunner; +import io.micrometer.tracing.test.reporter.BuildingBlocks; +import io.micrometer.tracing.test.simple.SpansAssert; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.micrometer.observation.ByteBufGetter; +import io.rsocket.micrometer.observation.ByteBufSetter; +import io.rsocket.micrometer.observation.ObservationRequesterRSocketProxy; +import io.rsocket.micrometer.observation.ObservationResponderRSocketProxy; +import io.rsocket.micrometer.observation.RSocketRequesterTracingObservationHandler; +import io.rsocket.micrometer.observation.RSocketResponderTracingObservationHandler; +import io.rsocket.plugins.RSocketInterceptor; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.Deque; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class ObservationIntegrationTest extends SampleTestRunner { + private static final MeterRegistry registry = new SimpleMeterRegistry(); + private static final ObservationRegistry observationRegistry = ObservationRegistry.create(); + + static { + observationRegistry + .observationConfig() + .observationHandler(new DefaultMeterObservationHandler(registry)); + } + + private final RSocketInterceptor requesterInterceptor; + private final RSocketInterceptor responderInterceptor; + + ObservationIntegrationTest() { + super(SampleRunnerConfig.builder().build()); + requesterInterceptor = + reactiveSocket -> new ObservationRequesterRSocketProxy(reactiveSocket, observationRegistry); + + responderInterceptor = + reactiveSocket -> new ObservationResponderRSocketProxy(reactiveSocket, observationRegistry); + } + + private CloseableChannel server; + private RSocket client; + private AtomicInteger counter; + + @Override + public BiConsumer>> + customizeObservationHandlers() { + return (buildingBlocks, observationHandlers) -> { + observationHandlers.addFirst( + new RSocketRequesterTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufSetter(), + false)); + observationHandlers.addFirst( + new RSocketResponderTracingObservationHandler( + buildingBlocks.getTracer(), + buildingBlocks.getPropagator(), + new ByteBufGetter(), + false)); + }; + } + + @AfterEach + public void teardown() { + if (server != null) { + server.dispose(); + } + } + + private void testRequest() { + counter.set(0); + client.requestResponse(DefaultPayload.create("REQUEST", "META")).block(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testStream() { + counter.set(0); + client.requestStream(DefaultPayload.create("start")).blockLast(); + + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testRequestChannel() { + counter.set(0); + client.requestChannel(Mono.just(DefaultPayload.create("start"))).blockFirst(); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + private void testFireAndForget() { + counter.set(0); + client.fireAndForget(DefaultPayload.create("start")).subscribe(); + Awaitility.await().atMost(Duration.ofSeconds(50)).until(() -> counter.get() == 1); + assertThat(counter).as("Server did not see the request.").hasValue(1); + } + + @Override + public SampleTestRunnerConsumer yourCode() { + return (bb, meterRegistry) -> { + counter = new AtomicInteger(); + server = + RSocketServer.create( + (setup, sendingSocket) -> { + sendingSocket.onClose().subscribe(); + + return Mono.just( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.just(DefaultPayload.create("RESPONSE", "METADATA")); + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Flux.range(1, 10_000) + .map(i -> DefaultPayload.create("data -> " + i)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + counter.incrementAndGet(); + return Flux.from(payloads); + } + + @Override + public Mono fireAndForget(Payload payload) { + payload.release(); + counter.incrementAndGet(); + return Mono.empty(); + } + }); + }) + .interceptors(registry -> registry.forResponder(responderInterceptor)) + .bind(TcpServerTransport.create("localhost", 0)) + .block(); + + client = + RSocketConnector.create() + .interceptors(registry -> registry.forRequester(requesterInterceptor)) + .connect(TcpClientTransport.create(server.address())) + .block(); + + testRequest(); + + testStream(); + + testRequestChannel(); + + testFireAndForget(); + + // @formatter:off + SpansAssert.assertThat(bb.getFinishedSpans()) + .haveSameTraceId() + // "request_*" + "handle" x 4 + .hasNumberOfSpansEqualTo(8) + .hasNumberOfSpansWithNameEqualTo("handle", 4) + .forAllSpansWithNameEqualTo("handle", span -> span.hasTagWithKey("rsocket.request-type")) + .hasASpanWithNameIgnoreCase("request_stream") + .thenASpanWithNameEqualToIgnoreCase("request_stream") + .hasTag("rsocket.request-type", "REQUEST_STREAM") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_channel") + .thenASpanWithNameEqualToIgnoreCase("request_channel") + .hasTag("rsocket.request-type", "REQUEST_CHANNEL") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_fnf") + .thenASpanWithNameEqualToIgnoreCase("request_fnf") + .hasTag("rsocket.request-type", "REQUEST_FNF") + .backToSpans() + .hasASpanWithNameIgnoreCase("request_response") + .thenASpanWithNameEqualToIgnoreCase("request_response") + .hasTag("rsocket.request-type", "REQUEST_RESPONSE"); + + MeterRegistryAssert.assertThat(registry) + .hasTimerWithNameAndTags( + "rsocket.response", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.fnf", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_FNF"))) + .hasTimerWithNameAndTags( + "rsocket.request", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_RESPONSE"))) + .hasTimerWithNameAndTags( + "rsocket.channel", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_CHANNEL"))) + .hasTimerWithNameAndTags( + "rsocket.stream", + Tags.of(Tag.of("error", "none"), Tag.of("rsocket.request-type", "REQUEST_STREAM"))); + // @formatter:on + }; + } + + @Override + protected MeterRegistry getMeterRegistry() { + return registry; + } + + @Override + protected ObservationRegistry getObservationRegistry() { + return observationRegistry; + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java b/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java new file mode 100644 index 000000000..5824918bc --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/resume/DisconnectableClientTransport.java @@ -0,0 +1,75 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.DuplexConnection; +import io.rsocket.transport.ClientTransport; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicReference; +import reactor.core.publisher.Mono; + +class DisconnectableClientTransport implements ClientTransport { + private final ClientTransport clientTransport; + private final AtomicReference curConnection = new AtomicReference<>(); + private long nextConnectPermitMillis; + + public DisconnectableClientTransport(ClientTransport clientTransport) { + this.clientTransport = clientTransport; + } + + @Override + public Mono connect() { + return Mono.defer( + () -> + now() < nextConnectPermitMillis + ? Mono.error(new ClosedChannelException()) + : clientTransport + .connect() + .map( + c -> { + if (curConnection.compareAndSet(null, c)) { + return c; + } else { + throw new IllegalStateException( + "Transport supports at most 1 connection"); + } + })); + } + + public void disconnect() { + disconnectFor(Duration.ZERO); + } + + public void disconnectPermanently() { + disconnectFor(Duration.ofDays(42)); + } + + public void disconnectFor(Duration cooldown) { + DuplexConnection cur = curConnection.getAndSet(null); + if (cur != null) { + nextConnectPermitMillis = now() + cooldown.toMillis(); + cur.dispose(); + } else { + throw new IllegalStateException("Trying to disconnect while not connected"); + } + } + + private static long now() { + return System.currentTimeMillis(); + } +} diff --git a/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java new file mode 100644 index 000000000..5eb78fabe --- /dev/null +++ b/rsocket-examples/src/test/java/io/rsocket/resume/ResumeIntegrationTest.java @@ -0,0 +1,229 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.resume; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.exceptions.RejectedResumeException; +import io.rsocket.exceptions.UnsupportedSetupException; +import io.rsocket.test.SlowTest; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; + +@SlowTest +public class ResumeIntegrationTest { + private static final String SERVER_HOST = "localhost"; + private static final int SERVER_PORT = 0; + + @Test + void timeoutOnPermanentDisconnect() { + CloseableChannel closeable = newServerRSocket().block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + + int sessionDurationSeconds = 5; + RSocket rSocket = newClientRSocket(clientTransport, sessionDurationSeconds).block(); + + Mono.delay(Duration.ofSeconds(1)).subscribe(v -> clientTransport.disconnectPermanently()); + + StepVerifier.create( + rSocket.requestChannel(testRequest()).then().doFinally(s -> closeable.dispose())) + .expectError(ClosedChannelException.class) + .verify(Duration.ofSeconds(7)); + } + + @Test + public void reconnectOnDisconnect() { + CloseableChannel closeable = newServerRSocket().block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + + int sessionDurationSeconds = 15; + RSocket rSocket = newClientRSocket(clientTransport, sessionDurationSeconds).block(); + + Flux.just(3, 20, 40, 75) + .flatMap(v -> Mono.delay(Duration.ofSeconds(v))) + .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(7))); + + AtomicInteger counter = new AtomicInteger(-1); + StepVerifier.create( + rSocket + .requestChannel(testRequest()) + .take(Duration.ofSeconds(600)) + .map(Payload::getDataUtf8) + .timeout(Duration.ofSeconds(12)) + .doOnNext(x -> throwOnNonContinuous(counter, x)) + .then() + .doFinally(s -> closeable.dispose())) + .expectComplete() + .verify(); + } + + @Test + public void reconnectOnMissingSession() { + + int serverSessionDuration = 2; + + CloseableChannel closeable = newServerRSocket(serverSessionDuration).block(); + + DisconnectableClientTransport clientTransport = + new DisconnectableClientTransport(clientTransport(closeable.address())); + int clientSessionDurationSeconds = 10; + + RSocket rSocket = newClientRSocket(clientTransport, clientSessionDurationSeconds).block(); + + Mono.delay(Duration.ofSeconds(1)) + .subscribe(v -> clientTransport.disconnectFor(Duration.ofSeconds(3))); + + StepVerifier.create( + rSocket.requestChannel(testRequest()).then().doFinally(s -> closeable.dispose())) + .expectError() + .verify(Duration.ofSeconds(5)); + + StepVerifier.create(rSocket.onClose()) + .expectErrorMatches( + err -> + err instanceof RejectedResumeException + && "unknown resume token".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + } + + @Test + void serverMissingResume() { + CloseableChannel closeableChannel = + RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)) + .block(); + + RSocket rSocket = + RSocketConnector.create() + .resume(new Resume()) + .connect(clientTransport(closeableChannel.address())) + .block(); + + StepVerifier.create(rSocket.onClose().doFinally(s -> closeableChannel.dispose())) + .expectErrorMatches( + err -> + err instanceof UnsupportedSetupException + && "resume not supported".equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + + Assertions.assertThat(rSocket.isDisposed()).isTrue(); + } + + static ClientTransport clientTransport(InetSocketAddress address) { + return TcpClientTransport.create(address); + } + + static ServerTransport serverTransport(String host, int port) { + return TcpServerTransport.create(host, port); + } + + private static Flux testRequest() { + return Flux.interval(Duration.ofMillis(500)) + .map(v -> DefaultPayload.create("client_request")) + .onBackpressureDrop(); + } + + private void throwOnNonContinuous(AtomicInteger counter, String x) { + int curValue = Integer.parseInt(x); + int prevValue = counter.get(); + if (prevValue >= 0) { + int dif = curValue - prevValue; + if (dif != 1) { + throw new IllegalStateException( + String.format( + "Payload values are expected to be continuous numbers: %d %d", + prevValue, curValue)); + } + } + counter.set(curValue); + } + + private static Mono newClientRSocket( + DisconnectableClientTransport clientTransport, int sessionDurationSeconds) { + return RSocketConnector.create() + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .storeFactory(t -> new InMemoryResumableFramesStore("client", t, 500_000)) + .cleanupStoreOnKeepAlive() + .retry(Retry.fixedDelay(Long.MAX_VALUE, Duration.ofSeconds(1)))) + .keepAlive(Duration.ofSeconds(5), Duration.ofMinutes(5)) + .connect(clientTransport); + } + + private static Mono newServerRSocket() { + return newServerRSocket(15); + } + + private static Mono newServerRSocket(int sessionDurationSeconds) { + return RSocketServer.create(SocketAcceptor.with(new TestResponderRSocket())) + .resume( + new Resume() + .sessionDuration(Duration.ofSeconds(sessionDurationSeconds)) + .cleanupStoreOnKeepAlive() + .storeFactory(t -> new InMemoryResumableFramesStore("server", t, 500_000))) + .bind(serverTransport(SERVER_HOST, SERVER_PORT)); + } + + private static class TestResponderRSocket implements RSocket { + + AtomicInteger counter = new AtomicInteger(); + + @Override + public Flux requestChannel(Publisher payloads) { + return duplicate( + Flux.interval(Duration.ofMillis(1)) + .onBackpressureLatest() + .publishOn(Schedulers.boundedElastic()), + 20) + .map(v -> DefaultPayload.create(String.valueOf(counter.getAndIncrement()))) + .takeUntilOther(Flux.from(payloads).then()); + } + + private Flux duplicate(Flux f, int n) { + Flux r = Flux.empty(); + for (int i = 0; i < n; i++) { + r = r.mergeWith(f); + } + return r; + } + } +} diff --git a/rsocket-examples/src/test/resources/logback-test.xml b/rsocket-examples/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-examples/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-load-balancer/build.gradle b/rsocket-load-balancer/build.gradle new file mode 100644 index 000000000..6d91324ae --- /dev/null +++ b/rsocket-load-balancer/build.gradle @@ -0,0 +1,39 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + +dependencies { + api project(':rsocket-core') + + implementation 'org.slf4j:slf4j-api' + + testImplementation project(':rsocket-test') + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' + testImplementation 'org.mockito:mockito-core' + testImplementation 'org.assertj:assertj-core' + testImplementation 'io.projectreactor:reactor-test' + + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' + testRuntimeOnly 'ch.qos.logback:logback-classic' +} + +description = 'Transparent Load Balancer for RSocket' diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java new file mode 100644 index 000000000..6329da826 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancedRSocketMono.java @@ -0,0 +1,989 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +import io.rsocket.*; +import io.rsocket.client.filter.RSocketSupplier; +import io.rsocket.stat.Ewma; +import io.rsocket.stat.FrugalQuantile; +import io.rsocket.stat.Median; +import io.rsocket.stat.Quantile; +import io.rsocket.util.Clock; +import java.nio.channels.ClosedChannelException; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Optional; +import java.util.Random; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.CoreSubscriber; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Operators; +import reactor.util.context.Context; +import reactor.util.retry.Retry; + +/** + * An implementation of {@link Mono} that load balances across a pool of RSockets and emits one when + * it is subscribed to + * + *

It estimates the load of each RSocket based on statistics collected. + * + * @deprecated as of 1.1. in favor of {@link io.rsocket.loadbalance.LoadbalanceRSocketClient}. + */ +@Deprecated +public abstract class LoadBalancedRSocketMono extends Mono + implements Availability, Closeable { + + public static final double DEFAULT_EXP_FACTOR = 4.0; + public static final double DEFAULT_LOWER_QUANTILE = 0.2; + public static final double DEFAULT_HIGHER_QUANTILE = 0.8; + public static final double DEFAULT_MIN_PENDING = 1.0; + public static final double DEFAULT_MAX_PENDING = 2.0; + public static final int DEFAULT_MIN_APERTURE = 3; + public static final int DEFAULT_MAX_APERTURE = 100; + public static final long DEFAULT_MAX_REFRESH_PERIOD_MS = + TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES); + private static final Logger logger = LoggerFactory.getLogger(LoadBalancedRSocketMono.class); + private static final long APERTURE_REFRESH_PERIOD = Clock.unit().convert(15, TimeUnit.SECONDS); + private static final int EFFORT = 5; + private static final long DEFAULT_INITIAL_INTER_ARRIVAL_TIME = + Clock.unit().convert(1L, TimeUnit.SECONDS); + private static final int DEFAULT_INTER_ARRIVAL_FACTOR = 500; + + private static final FailingRSocket FAILING_REACTIVE_SOCKET = new FailingRSocket(); + protected final Mono rSocketMono; + private final double minPendings; + private final double maxPendings; + private final int minAperture; + private final int maxAperture; + private final long maxRefreshPeriod; + private final double expFactor; + private final Quantile lowerQuantile; + private final Quantile higherQuantile; + private final ArrayList activeSockets; + private final Ewma pendings; + private final MonoProcessor onClose = MonoProcessor.create(); + private final RSocketSupplierPool pool; + private final long weightedSocketRetries; + private final Duration weightedSocketBackOff; + private final Duration weightedSocketMaxBackOff; + private volatile int targetAperture; + private long lastApertureRefresh; + private long refreshPeriod; + private int pendingSockets; + private volatile long lastRefresh; + + /** + * @param factories the source (factories) of RSocket + * @param expFactor how aggressive is the algorithm toward outliers. A higher number means we send + * aggressively less traffic to a server slightly slower. + * @param lowQuantile the lower bound of the latency band of acceptable values. Any server below + * that value will be aggressively favored. + * @param highQuantile the higher bound of the latency band of acceptable values. Any server above + * that value will be aggressively penalized. + * @param minPendings The lower band of the average outstanding messages per server. + * @param maxPendings The higher band of the average outstanding messages per server. + * @param minAperture the minimum number of connections we want to maintain, independently of the + * load. + * @param maxAperture the maximum number of connections we want to maintain, independently of the + * load. + * @param maxRefreshPeriodMs the maximum time between two "refreshes" of the list of active + * RSocket. This is at that time that the slowest RSocket is closed. (unit is millisecond) + * @param weightedSocketRetries the number of times a weighted socket will attempt to retry when + * it receives an error before reconnecting. The default is 5 times. + * @param weightedSocketBackOff the duration a a weighted socket will add to each retry attempt. + * @param weightedSocketMaxBackOff the max duration a weighted socket will delay before retrying + * to connect. The default is 5 seconds. + */ + private LoadBalancedRSocketMono( + Publisher> factories, + double expFactor, + double lowQuantile, + double highQuantile, + double minPendings, + double maxPendings, + int minAperture, + int maxAperture, + long maxRefreshPeriodMs, + long weightedSocketRetries, + Duration weightedSocketBackOff, + Duration weightedSocketMaxBackOff) { + this.weightedSocketRetries = weightedSocketRetries; + this.weightedSocketBackOff = weightedSocketBackOff; + this.weightedSocketMaxBackOff = weightedSocketMaxBackOff; + this.expFactor = expFactor; + this.lowerQuantile = new FrugalQuantile(lowQuantile); + this.higherQuantile = new FrugalQuantile(highQuantile); + + this.activeSockets = new ArrayList<>(); + this.pendingSockets = 0; + + this.minPendings = minPendings; + this.maxPendings = maxPendings; + this.pendings = new Ewma(15, TimeUnit.SECONDS, (minPendings + maxPendings) / 2.0); + + this.minAperture = minAperture; + this.maxAperture = maxAperture; + this.targetAperture = minAperture; + + this.maxRefreshPeriod = Clock.unit().convert(maxRefreshPeriodMs, TimeUnit.MILLISECONDS); + this.lastApertureRefresh = Clock.now(); + this.refreshPeriod = Clock.unit().convert(15L, TimeUnit.SECONDS); + this.lastRefresh = Clock.now(); + this.pool = new RSocketSupplierPool(factories); + refreshSockets(); + + rSocketMono = Mono.fromSupplier(this::select); + + onClose.doFinally(signalType -> pool.dispose()).subscribe(); + } + + public static LoadBalancedRSocketMono create( + Publisher> factories) { + return create( + factories, + DEFAULT_EXP_FACTOR, + DEFAULT_LOWER_QUANTILE, + DEFAULT_HIGHER_QUANTILE, + DEFAULT_MIN_PENDING, + DEFAULT_MAX_PENDING, + DEFAULT_MIN_APERTURE, + DEFAULT_MAX_APERTURE, + DEFAULT_MAX_REFRESH_PERIOD_MS); + } + + public static LoadBalancedRSocketMono create( + Publisher> factories, + double expFactor, + double lowQuantile, + double highQuantile, + double minPendings, + double maxPendings, + int minAperture, + int maxAperture, + long maxRefreshPeriodMs, + long weightedSocketRetries, + Duration weightedSocketBackOff, + Duration weightedSocketMaxBackOff) { + return new LoadBalancedRSocketMono( + factories, + expFactor, + lowQuantile, + highQuantile, + minPendings, + maxPendings, + minAperture, + maxAperture, + maxRefreshPeriodMs, + weightedSocketRetries, + weightedSocketBackOff, + weightedSocketMaxBackOff) { + @Override + public void subscribe(CoreSubscriber s) { + rSocketMono.subscribe(s); + } + }; + } + + public static LoadBalancedRSocketMono create( + Publisher> factories, + double expFactor, + double lowQuantile, + double highQuantile, + double minPendings, + double maxPendings, + int minAperture, + int maxAperture, + long maxRefreshPeriodMs) { + return new LoadBalancedRSocketMono( + factories, + expFactor, + lowQuantile, + highQuantile, + minPendings, + maxPendings, + minAperture, + maxAperture, + maxRefreshPeriodMs, + 5, + Duration.ofMillis(500), + Duration.ofSeconds(5)) { + @Override + public void subscribe(CoreSubscriber s) { + rSocketMono.subscribe(s); + } + }; + } + + /** + * Responsible for: - refreshing the aperture - asynchronously adding/removing reactive sockets to + * match targetAperture - periodically append a new connection + */ + private synchronized void refreshSockets() { + refreshAperture(); + int n = activeSockets.size(); + if (n < targetAperture && !pool.isPoolEmpty()) { + logger.debug( + "aperture {} is below target {}, adding {} sockets", + n, + targetAperture, + targetAperture - n); + addSockets(targetAperture - n); + } else if (targetAperture < activeSockets.size()) { + logger.debug("aperture {} is above target {}, quicking 1 socket", n, targetAperture); + quickSlowestRS(); + } + + long now = Clock.now(); + if (now - lastRefresh >= refreshPeriod) { + long prev = refreshPeriod; + refreshPeriod = (long) Math.min(refreshPeriod * 1.5, maxRefreshPeriod); + logger.debug("Bumping refresh period, {}->{}", prev / 1000, refreshPeriod / 1000); + lastRefresh = now; + addSockets(1); + } + } + + private synchronized void addSockets(int numberOfNewSocket) { + int n = numberOfNewSocket; + int poolSize = pool.poolSize(); + if (n > poolSize) { + n = poolSize; + logger.debug( + "addSockets({}) restricted by the number of factories, i.e. addSockets({})", + numberOfNewSocket, + n); + } + + for (int i = 0; i < n; i++) { + Optional optional = pool.get(); + + if (optional.isPresent()) { + RSocketSupplier supplier = optional.get(); + WeightedSocket socket = new WeightedSocket(supplier, lowerQuantile, higherQuantile); + } else { + break; + } + } + } + + private synchronized void refreshAperture() { + int n = activeSockets.size(); + if (n == 0) { + return; + } + + double p = 0.0; + for (WeightedSocket wrs : activeSockets) { + p += wrs.getPending(); + } + p /= n + pendingSockets; + pendings.insert(p); + double avgPending = pendings.value(); + + long now = Clock.now(); + boolean underRateLimit = now - lastApertureRefresh > APERTURE_REFRESH_PERIOD; + if (avgPending < 1.0 && underRateLimit) { + updateAperture(targetAperture - 1, now); + } else if (2.0 < avgPending && underRateLimit) { + updateAperture(targetAperture + 1, now); + } + } + + /** + * Update the aperture value and ensure its value stays in the right range. + * + * @param newValue new aperture value + * @param now time of the change (for rate limiting purposes) + */ + private void updateAperture(int newValue, long now) { + int previous = targetAperture; + targetAperture = newValue; + targetAperture = Math.max(minAperture, targetAperture); + int maxAperture = Math.min(this.maxAperture, activeSockets.size() + pool.poolSize()); + targetAperture = Math.min(maxAperture, targetAperture); + lastApertureRefresh = now; + pendings.reset((minPendings + maxPendings) / 2); + + if (targetAperture != previous) { + logger.debug( + "Current pending={}, new target={}, previous target={}", + pendings.value(), + targetAperture, + previous); + } + } + + private synchronized void quickSlowestRS() { + if (activeSockets.size() <= 1) { + return; + } + + WeightedSocket slowest = null; + double lowestAvailability = Double.MAX_VALUE; + for (WeightedSocket socket : activeSockets) { + double load = socket.availability(); + if (load == 0.0) { + slowest = socket; + break; + } + if (socket.getPredictedLatency() != 0) { + load *= 1.0 / socket.getPredictedLatency(); + } + if (load < lowestAvailability) { + lowestAvailability = load; + slowest = socket; + } + } + + if (slowest != null) { + logger.debug("Disposing slowest WeightedSocket {}", slowest); + slowest.dispose(); + } + } + + @Override + public synchronized double availability() { + double currentAvailability = 0.0; + if (!activeSockets.isEmpty()) { + for (WeightedSocket rs : activeSockets) { + currentAvailability += rs.availability(); + } + currentAvailability /= activeSockets.size(); + } + + return currentAvailability; + } + + private synchronized RSocket select() { + refreshSockets(); + + if (activeSockets.isEmpty()) { + return FAILING_REACTIVE_SOCKET; + } + + int size = activeSockets.size(); + if (size == 1) { + return activeSockets.get(0); + } + + WeightedSocket rsc1 = null; + WeightedSocket rsc2 = null; + + Random rng = ThreadLocalRandom.current(); + for (int i = 0; i < EFFORT; i++) { + int i1 = rng.nextInt(size); + int i2 = rng.nextInt(size - 1); + if (i2 >= i1) { + i2++; + } + rsc1 = activeSockets.get(i1); + rsc2 = activeSockets.get(i2); + if (rsc1.availability() > 0.0 && rsc2.availability() > 0.0) { + break; + } + if (i + 1 == EFFORT && !pool.isPoolEmpty()) { + addSockets(1); + } + } + + double w1 = algorithmicWeight(rsc1); + double w2 = algorithmicWeight(rsc2); + if (w1 < w2) { + return rsc2; + } else { + return rsc1; + } + } + + private double algorithmicWeight(WeightedSocket socket) { + if (socket == null || socket.availability() == 0.0) { + return 0.0; + } + + int pendings = socket.getPending(); + double latency = socket.getPredictedLatency(); + + double low = lowerQuantile.estimation(); + double high = + Math.max( + higherQuantile.estimation(), + low * 1.001); // ensure higherQuantile > lowerQuantile + .1% + double bandWidth = Math.max(high - low, 1); + + if (latency < low) { + double alpha = (low - latency) / bandWidth; + double bonusFactor = Math.pow(1 + alpha, expFactor); + latency /= bonusFactor; + } else if (latency > high) { + double alpha = (latency - high) / bandWidth; + double penaltyFactor = Math.pow(1 + alpha, expFactor); + latency *= penaltyFactor; + } + + return socket.availability() * 1.0 / (1.0 + latency * (pendings + 1)); + } + + @Override + public synchronized String toString() { + return "LoadBalancer(a:" + + activeSockets.size() + + ", f: " + + pool.poolSize() + + ", avgPendings=" + + pendings.value() + + ", targetAperture=" + + targetAperture + + ", band=[" + + lowerQuantile.estimation() + + ", " + + higherQuantile.estimation() + + "])"; + } + + @Override + public void dispose() { + synchronized (this) { + activeSockets.forEach(WeightedSocket::dispose); + activeSockets.clear(); + onClose.onComplete(); + } + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + + /** + * (Null Object Pattern) This failing RSocket never succeed, it is useful for simplifying the code + * when dealing with edge cases. + */ + private static class FailingRSocket implements RSocket { + + private static final Mono errorVoid = Mono.error(NoAvailableRSocketException.INSTANCE); + private static final Mono errorPayload = + Mono.error(NoAvailableRSocketException.INSTANCE); + + @Override + public Mono fireAndForget(Payload payload) { + return errorVoid; + } + + @Override + public Mono requestResponse(Payload payload) { + return errorPayload; + } + + @Override + public Flux requestStream(Payload payload) { + return errorPayload.flux(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return errorPayload.flux(); + } + + @Override + public Mono metadataPush(Payload payload) { + return errorVoid; + } + + @Override + public double availability() { + return 0; + } + + @Override + public void dispose() {} + + @Override + public boolean isDisposed() { + return true; + } + + @Override + public Mono onClose() { + return Mono.empty(); + } + } + + /** + * Wrapper of a RSocket, it computes statistics about the req/resp calls and update availability + * accordingly. + */ + private class WeightedSocket implements LoadBalancerSocketMetrics, RSocket { + + private static final double STARTUP_PENALTY = Long.MAX_VALUE >> 12; + private final Quantile lowerQuantile; + private final Quantile higherQuantile; + private final long inactivityFactor; + private final MonoProcessor rSocketMono; + private volatile int pending; // instantaneous rate + private long stamp; // last timestamp we sent a request + private long stamp0; // last timestamp we sent a request or receive a response + private long duration; // instantaneous cumulative duration + + private Median median; + private Ewma interArrivalTime; + + private AtomicLong pendingStreams; // number of active streams + + private volatile double availability = 0.0; + private final MonoProcessor onClose = MonoProcessor.create(); + + WeightedSocket( + RSocketSupplier factory, + Quantile lowerQuantile, + Quantile higherQuantile, + int inactivityFactor) { + this.rSocketMono = MonoProcessor.create(); + this.lowerQuantile = lowerQuantile; + this.higherQuantile = higherQuantile; + this.inactivityFactor = inactivityFactor; + long now = Clock.now(); + this.stamp = now; + this.stamp0 = now; + this.duration = 0L; + this.pending = 0; + this.median = new Median(); + this.interArrivalTime = new Ewma(1, TimeUnit.MINUTES, DEFAULT_INITIAL_INTER_ARRIVAL_TIME); + this.pendingStreams = new AtomicLong(); + + logger.debug("Creating WeightedSocket {} from factory {}", WeightedSocket.this, factory); + + WeightedSocket.this + .onClose() + .doFinally( + s -> { + pool.accept(factory); + activeSockets.remove(WeightedSocket.this); + logger.debug( + "Removed {} from factory {} from activeSockets", WeightedSocket.this, factory); + }) + .subscribe(); + + factory + .get() + .retryWhen( + Retry.backoff(weightedSocketRetries, weightedSocketBackOff) + .maxBackoff(weightedSocketMaxBackOff)) + .doOnError( + throwable -> { + logger.error( + "error while connecting {} from factory {}", + WeightedSocket.this, + factory, + throwable); + WeightedSocket.this.dispose(); + }) + .subscribe( + rSocket -> { + // When RSocket is closed, close the WeightedSocket + rSocket + .onClose() + .doFinally( + signalType -> { + logger.info( + "RSocket {} from factory {} closed", WeightedSocket.this, factory); + WeightedSocket.this.dispose(); + }) + .subscribe(); + + // When the factory is closed, close the RSocket + factory + .onClose() + .doFinally( + signalType -> { + logger.info("Factory {} closed", factory); + rSocket.dispose(); + }) + .subscribe(); + + // When the WeightedSocket is closed, close the RSocket + WeightedSocket.this + .onClose() + .doFinally( + signalType -> { + logger.info( + "WeightedSocket {} from factory {} closed", + WeightedSocket.this, + factory); + rSocket.dispose(); + }) + .subscribe(); + + /*synchronized (LoadBalancedRSocketMono.this) { + if (activeSockets.size() >= targetAperture) { + quickSlowestRS(); + pendingSockets -= 1; + } + }*/ + rSocketMono.onNext(rSocket); + availability = 1.0; + if (!WeightedSocket.this + .isDisposed()) { // May be already disposed because of retryBackoff delay + activeSockets.add(WeightedSocket.this); + logger.debug( + "Added WeightedSocket {} from factory {} to activeSockets", + WeightedSocket.this, + factory); + } + }); + } + + WeightedSocket(RSocketSupplier factory, Quantile lowerQuantile, Quantile higherQuantile) { + this(factory, lowerQuantile, higherQuantile, DEFAULT_INTER_ARRIVAL_FACTOR); + } + + @Override + public Mono requestResponse(Payload payload) { + return rSocketMono.flatMap( + source -> + Mono.from( + subscriber -> + source + .requestResponse(payload) + .subscribe( + new LatencySubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); + } + + @Override + public Flux requestStream(Payload payload) { + + return rSocketMono.flatMapMany( + source -> + Flux.from( + subscriber -> + source + .requestStream(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); + } + + @Override + public Mono fireAndForget(Payload payload) { + + return rSocketMono.flatMap( + source -> { + return Mono.from( + subscriber -> + source + .fireAndForget(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); + }); + } + + @Override + public Mono metadataPush(Payload payload) { + return rSocketMono.flatMap( + source -> { + return Mono.from( + subscriber -> + source + .metadataPush(payload) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this))); + }); + } + + @Override + public Flux requestChannel(Publisher payloads) { + + return rSocketMono.flatMapMany( + source -> + Flux.from( + subscriber -> + source + .requestChannel(payloads) + .subscribe( + new CountingSubscriber<>( + Operators.toCoreSubscriber(subscriber), this)))); + } + + synchronized double getPredictedLatency() { + long now = Clock.now(); + long elapsed = Math.max(now - stamp, 1L); + + double weight; + double prediction = median.estimation(); + + if (prediction == 0.0) { + if (pending == 0) { + weight = 0.0; // first request + } else { + // subsequent requests while we don't have any history + weight = STARTUP_PENALTY + pending; + } + } else if (pending == 0 && elapsed > inactivityFactor * interArrivalTime.value()) { + // if we did't see any data for a while, we decay the prediction by inserting + // artificial 0.0 into the median + median.insert(0.0); + weight = median.estimation(); + } else { + double predicted = prediction * pending; + double instant = instantaneous(now); + + if (predicted < instant) { // NB: (0.0 < 0.0) == false + weight = instant / pending; // NB: pending never equal 0 here + } else { + // we are under the predictions + weight = prediction; + } + } + + return weight; + } + + int getPending() { + return pending; + } + + private synchronized long instantaneous(long now) { + return duration + (now - stamp0) * pending; + } + + private synchronized long incr() { + long now = Clock.now(); + interArrivalTime.insert(now - stamp); + duration += Math.max(0, now - stamp0) * pending; + pending += 1; + stamp = now; + stamp0 = now; + return now; + } + + private synchronized long decr(long timestamp) { + long now = Clock.now(); + duration += Math.max(0, now - stamp0) * pending - (now - timestamp); + pending -= 1; + stamp0 = now; + return now; + } + + private synchronized void observe(double rtt) { + median.insert(rtt); + lowerQuantile.insert(rtt); + higherQuantile.insert(rtt); + } + + @Override + public double availability() { + return availability; + } + + @Override + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + + @Override + public String toString() { + return "WeightedSocket(" + + "median=" + + median.estimation() + + " quantile-low=" + + lowerQuantile.estimation() + + " quantile-high=" + + higherQuantile.estimation() + + " inter-arrival=" + + interArrivalTime.value() + + " duration/pending=" + + (pending == 0 ? 0 : (double) duration / pending) + + " pending=" + + pending + + " availability= " + + availability() + + ")->"; + } + + @Override + public double medianLatency() { + return median.estimation(); + } + + @Override + public double lowerQuantileLatency() { + return lowerQuantile.estimation(); + } + + @Override + public double higherQuantileLatency() { + return higherQuantile.estimation(); + } + + @Override + public double interArrivalTime() { + return interArrivalTime.value(); + } + + @Override + public int pending() { + return pending; + } + + @Override + public long lastTimeUsedMillis() { + return stamp0; + } + + /** + * Subscriber wrapper used for request/response interaction model, measure and collect latency + * information. + */ + private class LatencySubscriber implements CoreSubscriber { + private final CoreSubscriber child; + private final WeightedSocket socket; + private final AtomicBoolean done; + private long start; + + LatencySubscriber(CoreSubscriber child, WeightedSocket socket) { + this.child = child; + this.socket = socket; + this.done = new AtomicBoolean(false); + } + + @Override + public Context currentContext() { + return child.currentContext(); + } + + @Override + public void onSubscribe(Subscription s) { + start = incr(); + child.onSubscribe( + new Subscription() { + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + if (done.compareAndSet(false, true)) { + s.cancel(); + decr(start); + } + } + }); + } + + @Override + public void onNext(U u) { + child.onNext(u); + } + + @Override + public void onError(Throwable t) { + if (done.compareAndSet(false, true)) { + child.onError(t); + long now = decr(start); + if (t instanceof TransportException || t instanceof ClosedChannelException) { + socket.dispose(); + } else if (t instanceof TimeoutException) { + observe(now - start); + } + } + } + + @Override + public void onComplete() { + if (done.compareAndSet(false, true)) { + long now = decr(start); + observe(now - start); + child.onComplete(); + } + } + } + + /** + * Subscriber wrapper used for stream like interaction model, it only counts the number of + * active streams + */ + private class CountingSubscriber implements CoreSubscriber { + private final CoreSubscriber child; + private final WeightedSocket socket; + + CountingSubscriber(CoreSubscriber child, WeightedSocket socket) { + this.child = child; + this.socket = socket; + } + + @Override + public Context currentContext() { + return child.currentContext(); + } + + @Override + public void onSubscribe(Subscription s) { + socket.pendingStreams.incrementAndGet(); + child.onSubscribe(s); + } + + @Override + public void onNext(U u) { + child.onNext(u); + } + + @Override + public void onError(Throwable t) { + socket.pendingStreams.decrementAndGet(); + child.onError(t); + if (t instanceof TransportException || t instanceof ClosedChannelException) { + logger.debug("Disposing {} from activeSockets because of error {}", socket, t); + socket.dispose(); + } + } + + @Override + public void onComplete() { + socket.pendingStreams.decrementAndGet(); + child.onComplete(); + } + } + } +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java new file mode 100644 index 000000000..0cb35d180 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/LoadBalancerSocketMetrics.java @@ -0,0 +1,67 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +import io.rsocket.Availability; + +@Deprecated +/** A contract for the metrics managed by {@link LoadBalancedRSocketMono} per socket. */ +public interface LoadBalancerSocketMetrics extends Availability { + + /** + * Median value of latency as per last calculation. This is not calculated per invocation. + * + * @return Median latency. + */ + double medianLatency(); + + /** + * Lower quantile of latency as per last calculation. This is not calculated per invocation. + * + * @return Median latency. + */ + double lowerQuantileLatency(); + + /** + * Higher quantile value of latency as per last calculation. This is not calculated per + * invocation. + * + * @return Median latency. + */ + double higherQuantileLatency(); + + /** + * An exponentially weighted moving average value of the time between two requests. + * + * @return Inter arrival time. + */ + double interArrivalTime(); + + /** + * Number of pending requests at this moment. + * + * @return Number of pending requests at this moment. + */ + int pending(); + + /** + * Last time this socket was used i.e. either a request was sent or a response was received. + * + * @return Last time used in millis since epoch. + */ + long lastTimeUsedMillis(); +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java new file mode 100644 index 000000000..295d25d75 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/NoAvailableRSocketException.java @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +@Deprecated +/** An exception that indicates that no RSocket was available. */ +public final class NoAvailableRSocketException extends Exception { + + /** + * The single instance of this type. Note that it is initialized without any stack trace. + */ + public static final NoAvailableRSocketException INSTANCE; + + private static final long serialVersionUID = -2785312562743351184L; + + static { + NoAvailableRSocketException exception = new NoAvailableRSocketException(); + exception.setStackTrace( + new StackTraceElement[] { + new StackTraceElement(exception.getClass().getName(), "", null, -1) + }); + + INSTANCE = exception; + } + + private NoAvailableRSocketException() {}; +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java new file mode 100644 index 000000000..8249083ad --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/RSocketSupplierPool.java @@ -0,0 +1,197 @@ +package io.rsocket.client; + +import io.rsocket.Closeable; +import io.rsocket.client.filter.RSocketSupplier; +import java.time.Duration; +import java.util.*; +import java.util.concurrent.ThreadLocalRandom; +import java.util.function.Consumer; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +@Deprecated +public class RSocketSupplierPool + implements Supplier>, Consumer, Closeable { + private static final Logger logger = LoggerFactory.getLogger(RSocketSupplierPool.class); + private static final int EFFORT = 5; + + private final ArrayList factoryPool; + private final ArrayList leasedSuppliers; + + private final MonoProcessor onClose; + + public RSocketSupplierPool(Publisher> publisher) { + this.onClose = MonoProcessor.create(); + this.factoryPool = new ArrayList<>(); + this.leasedSuppliers = new ArrayList<>(); + + Disposable disposable = + Flux.from(publisher) + .doOnNext(this::handleNewFactories) + .onErrorResume( + t -> { + logger.error("error streaming RSocketSuppliers", t); + return Mono.delay(Duration.ofSeconds(10)).then(Mono.error(t)); + }) + .subscribe(); + + onClose.doFinally(s -> disposable.dispose()).subscribe(); + } + + private synchronized void handleNewFactories(Collection newFactories) { + Set current = new HashSet<>(factoryPool.size() + leasedSuppliers.size()); + current.addAll(factoryPool); + current.addAll(leasedSuppliers); + + Set removed = new HashSet<>(current); + removed.removeAll(newFactories); + + Set added = new HashSet<>(newFactories); + added.removeAll(current); + + boolean changed = false; + Iterator it0 = leasedSuppliers.iterator(); + while (it0.hasNext()) { + RSocketSupplier supplier = it0.next(); + if (removed.contains(supplier)) { + it0.remove(); + try { + changed = true; + supplier.dispose(); + } catch (Exception e) { + logger.warn("Exception while closing a RSocket", e); + } + } + } + + Iterator it1 = factoryPool.iterator(); + while (it1.hasNext()) { + RSocketSupplier supplier = it1.next(); + if (removed.contains(supplier)) { + it1.remove(); + try { + changed = true; + supplier.dispose(); + } catch (Exception e) { + logger.warn("Exception while closing a RSocket", e); + } + } + } + + factoryPool.addAll(added); + if (!added.isEmpty()) { + changed = true; + } + + if (changed && logger.isDebugEnabled()) { + StringBuilder msgBuilder = new StringBuilder(); + msgBuilder + .append("\nUpdated active factories (size: ") + .append(factoryPool.size()) + .append(")\n"); + for (RSocketSupplier f : factoryPool) { + msgBuilder.append(" + ").append(f).append('\n'); + } + msgBuilder.append("Active sockets:\n"); + for (RSocketSupplier socket : leasedSuppliers) { + msgBuilder.append(" + ").append(socket).append('\n'); + } + logger.debug(msgBuilder.toString()); + } + } + + @Override + public synchronized void accept(RSocketSupplier rSocketSupplier) { + boolean contained = leasedSuppliers.remove(rSocketSupplier); + if (contained + && !rSocketSupplier + .isDisposed()) { // only added leasedSupplier back to factoryPool if it's still there + factoryPool.add(rSocketSupplier); + } + } + + @Override + public synchronized Optional get() { + Optional optional = Optional.empty(); + int poolSize = factoryPool.size(); + if (poolSize == 1) { + RSocketSupplier rSocketSupplier = factoryPool.get(0); + if (rSocketSupplier.availability() > 0.0) { + factoryPool.remove(0); + leasedSuppliers.add(rSocketSupplier); + logger.debug("Added {} to leasedSuppliers", rSocketSupplier); + optional = Optional.of(rSocketSupplier); + } + } else if (poolSize > 1) { + Random rng = ThreadLocalRandom.current(); + int size = factoryPool.size(); + RSocketSupplier factory0 = null; + RSocketSupplier factory1 = null; + int i0 = 0; + int i1 = 0; + for (int i = 0; i < EFFORT; i++) { + i0 = rng.nextInt(size); + i1 = rng.nextInt(size - 1); + if (i1 >= i0) { + i1++; + } + factory0 = factoryPool.get(i0); + factory1 = factoryPool.get(i1); + if (factory0.availability() > 0.0 && factory1.availability() > 0.0) { + break; + } + } + if (factory0.availability() > factory1.availability()) { + factoryPool.remove(i0); + leasedSuppliers.add(factory0); + logger.debug("Added {} to leasedSuppliers", factory0); + optional = Optional.of(factory0); + } else { + factoryPool.remove(i1); + leasedSuppliers.add(factory1); + logger.debug("Added {} to leasedSuppliers", factory1); + optional = Optional.of(factory1); + } + } + + return optional; + } + + @Override + public Mono onClose() { + return onClose; + } + + @Override + public void dispose() { + if (!onClose.isDisposed()) { + onClose.onComplete(); + + close(factoryPool); + close(leasedSuppliers); + } + } + + private void close(Collection suppliers) { + for (RSocketSupplier supplier : suppliers) { + try { + supplier.dispose(); + } catch (Throwable t) { + } + } + } + + public synchronized int poolSize() { + return factoryPool.size(); + } + + public synchronized boolean isPoolEmpty() { + return factoryPool.isEmpty(); + } +} diff --git a/src/main/java/io/reactivesocket/exceptions/ConnectionException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java similarity index 57% rename from src/main/java/io/reactivesocket/exceptions/ConnectionException.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java index 0fe0aa7c5..a32ac2224 100644 --- a/src/main/java/io/reactivesocket/exceptions/ConnectionException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/TimeoutException.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; -public class ConnectionException extends Throwable implements Retryable { - public ConnectionException(String message) { - super(message); - } +package io.rsocket.client; - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } +@Deprecated +public final class TimeoutException extends Exception { + + public static final TimeoutException INSTANCE = new TimeoutException(); + + private static final long serialVersionUID = -3094321310317812063L; + + private TimeoutException() {} } diff --git a/src/main/java/io/reactivesocket/exceptions/TransportException.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java similarity index 60% rename from src/main/java/io/reactivesocket/exceptions/TransportException.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java index e4e4029a0..4779c6d4d 100644 --- a/src/main/java/io/reactivesocket/exceptions/TransportException.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/TransportException.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2018 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,16 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; -public class TransportException extends Throwable { - public TransportException(Throwable t) { - super(t); - } +package io.rsocket.client; - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } +@Deprecated +public final class TransportException extends Throwable { + private static final long serialVersionUID = -3339846338318701123L; + + public TransportException(Throwable t) { + super(t); + } } diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java new file mode 100644 index 000000000..beb424797 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/BackupRequestSocket.java @@ -0,0 +1,231 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client.filter; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.stat.FrugalQuantile; +import io.rsocket.stat.Quantile; +import io.rsocket.util.Clock; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@Deprecated +public class BackupRequestSocket implements RSocket { + private final ScheduledExecutorService executor; + private final RSocket child; + private final Quantile q; + + public BackupRequestSocket(RSocket child, double quantile, ScheduledExecutorService executor) { + this.child = child; + this.executor = executor; + q = new FrugalQuantile(quantile); + } + + public BackupRequestSocket(RSocket child, double quantile) { + this(child, quantile, Executors.newScheduledThreadPool(2)); + } + + public BackupRequestSocket(RSocket child) { + this(child, 0.99); + } + + @Override + public Mono fireAndForget(Payload payload) { + return child.fireAndForget(payload); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.from( + subscriber -> { + Subscriber oneSubscriber = new OneSubscriber<>(subscriber); + Subscriber backupRequest = + new FirstRequestSubscriber(oneSubscriber, () -> child.requestResponse(payload)); + child.requestResponse(payload).subscribe(backupRequest); + }); + } + + @Override + public Flux requestStream(Payload payload) { + return child.requestStream(payload); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return child.requestChannel(payloads); + } + + @Override + public Mono metadataPush(Payload payload) { + return child.metadataPush(payload); + } + + @Override + public double availability() { + return child.availability(); + } + + @Override + public void dispose() { + child.dispose(); + } + + @Override + public boolean isDisposed() { + return child.isDisposed(); + } + + @Override + public Mono onClose() { + return child.onClose(); + } + + @Override + public String toString() { + return "BackupRequest(q=" + q + ")->" + child; + } + + private static class OneSubscriber implements Subscriber { + private final Subscriber subscriber; + private final AtomicBoolean firstEvent; + private final AtomicBoolean firstTerminal; + + private OneSubscriber(Subscriber subscriber) { + this.subscriber = subscriber; + this.firstEvent = new AtomicBoolean(false); + this.firstTerminal = new AtomicBoolean(false); + } + + @Override + public void onSubscribe(Subscription s) { + subscriber.onSubscribe(s); + } + + @Override + public void onNext(T t) { + if (firstEvent.compareAndSet(false, true)) { + subscriber.onNext(t); + } + } + + @Override + public void onError(Throwable t) { + if (firstTerminal.compareAndSet(false, true)) { + subscriber.onError(t); + } + } + + @Override + public void onComplete() { + if (firstTerminal.compareAndSet(false, true)) { + subscriber.onComplete(); + } + } + } + + private class FirstRequestSubscriber implements Subscriber { + private final Subscriber oneSubscriber; + private final Supplier> action; + private long start; + private ScheduledFuture future; + + private FirstRequestSubscriber( + Subscriber oneSubscriber, Supplier> action) { + this.oneSubscriber = oneSubscriber; + this.action = action; + } + + @Override + public void onSubscribe(Subscription s) { + start = Clock.now(); + if (q.estimation() > 0) { + future = + executor.schedule( + () -> action.get().subscribe(new BackupRequestSubscriber<>(oneSubscriber, s)), + (long) q.estimation(), + TimeUnit.MICROSECONDS); + } + oneSubscriber.onSubscribe(s); + } + + @Override + public void onNext(Payload t) { + if (future != null) { + future.cancel(true); + } + oneSubscriber.onNext(t); + long latency = Clock.now() - start; + q.insert(latency); + } + + @Override + public void onError(Throwable t) { + oneSubscriber.onError(t); + } + + @Override + public void onComplete() { + oneSubscriber.onComplete(); + } + } + + private class BackupRequestSubscriber implements Subscriber { + private final Subscriber oneSubscriber; + private final Subscription firstRequestSubscription; + private long start; + + private BackupRequestSubscriber( + Subscriber oneSubscriber, Subscription firstRequestSubscription) { + this.oneSubscriber = oneSubscriber; + this.firstRequestSubscription = firstRequestSubscription; + } + + @Override + public void onSubscribe(Subscription s) { + start = Clock.now(); + s.request(1); + } + + @Override + public void onNext(T t) { + firstRequestSubscription.cancel(); + oneSubscriber.onNext(t); + long latency = Clock.now() - start; + q.insert(latency); + } + + @Override + public void onError(Throwable t) { + oneSubscriber.onError(t); + } + + @Override + public void onComplete() { + oneSubscriber.onComplete(); + } + } +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java new file mode 100644 index 000000000..aaf9f71e6 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSocketSupplier.java @@ -0,0 +1,162 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client.filter; + +import io.rsocket.Availability; +import io.rsocket.Closeable; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.stat.Ewma; +import io.rsocket.util.Clock; +import io.rsocket.util.RSocketProxy; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.MonoProcessor; + +/** */ +@Deprecated +public class RSocketSupplier implements Availability, Supplier>, Closeable { + + private static final double EPSILON = 1e-4; + + private Supplier> rSocketSupplier; + + private final MonoProcessor onClose; + + private final long tau; + private long stamp; + private final Ewma errorPercentage; + + public RSocketSupplier(Supplier> rSocketSupplier, long halfLife, TimeUnit unit) { + this.rSocketSupplier = rSocketSupplier; + this.tau = Clock.unit().convert((long) (halfLife / Math.log(2)), unit); + this.stamp = Clock.now(); + this.errorPercentage = new Ewma(halfLife, unit, 1.0); + this.onClose = MonoProcessor.create(); + } + + public RSocketSupplier(Supplier> rSocketSupplier) { + this(rSocketSupplier, 5, TimeUnit.SECONDS); + } + + @Override + public double availability() { + double e = errorPercentage.value(); + if (Clock.now() - stamp > tau) { + // If the window is expired artificially increase the availability + double a = Math.min(1.0, e + 0.5); + errorPercentage.reset(a); + } + if (e < EPSILON) { + e = 0.0; + } else if (1.0 - EPSILON < e) { + e = 1.0; + } + + return e; + } + + private synchronized void updateErrorPercentage(double value) { + errorPercentage.insert(value); + stamp = Clock.now(); + } + + @Override + public Mono get() { + return rSocketSupplier + .get() + .doOnNext(o -> updateErrorPercentage(1.0)) + .doOnError(t -> updateErrorPercentage(0.0)) + .map(AvailabilityAwareRSocketProxy::new); + } + + @Override + public void dispose() { + onClose.onComplete(); + } + + @Override + public boolean isDisposed() { + return onClose.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + + private class AvailabilityAwareRSocketProxy extends RSocketProxy { + public AvailabilityAwareRSocketProxy(RSocket source) { + super(source); + + onClose.doFinally(signalType -> source.dispose()).subscribe(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return source + .fireAndForget(payload) + .doOnError(t -> errorPercentage.insert(0.0)) + .doOnSuccess(v -> updateErrorPercentage(1.0)); + } + + @Override + public Mono requestResponse(Payload payload) { + return source + .requestResponse(payload) + .doOnError(t -> errorPercentage.insert(0.0)) + .doOnSuccess(p -> updateErrorPercentage(1.0)); + } + + @Override + public Flux requestStream(Payload payload) { + return source + .requestStream(payload) + .doOnError(th -> errorPercentage.insert(0.0)) + .doOnComplete(() -> updateErrorPercentage(1.0)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return source + .requestChannel(payloads) + .doOnError(th -> errorPercentage.insert(0.0)) + .doOnComplete(() -> updateErrorPercentage(1.0)); + } + + @Override + public Mono metadataPush(Payload payload) { + return source + .metadataPush(payload) + .doOnError(t -> errorPercentage.insert(0.0)) + .doOnSuccess(v -> updateErrorPercentage(1.0)); + } + + @Override + public double availability() { + // If the window is expired set success and failure to zero and return + // the child availability + if (Clock.now() - stamp > tau) { + updateErrorPercentage(1.0); + } + return source.availability() * errorPercentage.value(); + } + } +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java new file mode 100644 index 000000000..89ff74143 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/RSockets.java @@ -0,0 +1,162 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client.filter; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.RSocketProxy; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@Deprecated +public final class RSockets { + + private RSockets() { + // No Instances. + } + + /** + * Provides a mapping function to wrap a {@code RSocket} such that all requests will timeout, if + * not completed after the specified {@code timeout}. + * + * @param timeout timeout duration. + * @return Function to transform any socket into a timeout socket. + */ + public static Function timeout(Duration timeout) { + return source -> + new RSocketProxy(source) { + @Override + public Mono fireAndForget(Payload payload) { + return source.fireAndForget(payload).timeout(timeout); + } + + @Override + public Mono requestResponse(Payload payload) { + return source.requestResponse(payload).timeout(timeout); + } + + @Override + public Flux requestStream(Payload payload) { + return source.requestStream(payload).timeout(timeout); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return source.requestChannel(payloads).timeout(timeout); + } + + @Override + public Mono metadataPush(Payload payload) { + return source.metadataPush(payload).timeout(timeout); + } + }; + } + + /** + * Provides a mapping function to wrap a {@code RSocket} such that a call to {@link + * RSocket#dispose()} does not cancel all pending requests. Instead, it will wait for all pending + * requests to finish and then close the socket. + * + * @return Function to transform any socket into a safe closing socket. + */ + public static Function safeClose() { + return source -> + new RSocketProxy(source) { + final AtomicInteger count = new AtomicInteger(); + final AtomicBoolean closed = new AtomicBoolean(); + + @Override + public Mono fireAndForget(Payload payload) { + return source + .fireAndForget(payload) + .doOnSubscribe(s -> count.incrementAndGet()) + .doFinally( + signalType -> { + if (count.decrementAndGet() == 0 && closed.get()) { + source.dispose(); + } + }); + } + + @Override + public Mono requestResponse(Payload payload) { + return source + .requestResponse(payload) + .doOnSubscribe(s -> count.incrementAndGet()) + .doFinally( + signalType -> { + if (count.decrementAndGet() == 0 && closed.get()) { + source.dispose(); + } + }); + } + + @Override + public Flux requestStream(Payload payload) { + return source + .requestStream(payload) + .doOnSubscribe(s -> count.incrementAndGet()) + .doFinally( + signalType -> { + if (count.decrementAndGet() == 0 && closed.get()) { + source.dispose(); + } + }); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return source + .requestChannel(payloads) + .doOnSubscribe(s -> count.incrementAndGet()) + .doFinally( + signalType -> { + if (count.decrementAndGet() == 0 && closed.get()) { + source.dispose(); + } + }); + } + + @Override + public Mono metadataPush(Payload payload) { + return source + .metadataPush(payload) + .doOnSubscribe(s -> count.incrementAndGet()) + .doFinally( + signalType -> { + if (count.decrementAndGet() == 0 && closed.get()) { + source.dispose(); + } + }); + } + + @Override + public void dispose() { + if (closed.compareAndSet(false, true)) { + if (count.get() == 0) { + source.dispose(); + } + } + } + }; + } +} diff --git a/src/main/java/io/reactivesocket/Payload.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java similarity index 69% rename from src/main/java/io/reactivesocket/Payload.java rename to rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java index b69807a8b..55ce5646c 100644 --- a/src/main/java/io/reactivesocket/Payload.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/filter/package-info.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,12 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket; -import java.nio.ByteBuffer; +@NonNullApi +package io.rsocket.client.filter; -public interface Payload -{ - ByteBuffer getData(); - ByteBuffer getMetadata(); -} +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java new file mode 100644 index 000000000..ec21dee96 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/client/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +package io.rsocket.client; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java new file mode 100644 index 000000000..3968ec0a4 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Ewma.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.stat; + +import io.rsocket.util.Clock; +import java.util.concurrent.TimeUnit; + +/** + * Compute the exponential weighted moving average of a series of values. The time at which you + * insert the value into `Ewma` is used to compute a weight (recent points are weighted higher). The + * parameter for defining the convergence speed (like most decay process) is the half-life. + * + *

e.g. with a half-life of 10 unit, if you insert 100 at t=0 and 200 at t=10 the ewma will be + * equal to (200 - 100)/2 = 150 (half of the distance between the new and the old value) + */ +@Deprecated +public class Ewma { + private final long tau; + private volatile long stamp; + private volatile double ewma; + + public Ewma(long halfLife, TimeUnit unit, double initialValue) { + this.tau = Clock.unit().convert((long) (halfLife / Math.log(2)), unit); + stamp = 0L; + ewma = initialValue; + } + + public synchronized void insert(double x) { + long now = Clock.now(); + double elapsed = Math.max(0, now - stamp); + stamp = now; + + double w = Math.exp(-elapsed / tau); + ewma = w * ewma + (1.0 - w) * x; + } + + public synchronized void reset(double value) { + stamp = 0L; + ewma = value; + } + + public double value() { + return ewma; + } + + @Override + public String toString() { + return "Ewma(value=" + ewma + ", age=" + (Clock.now() - stamp) + ")"; + } +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java new file mode 100644 index 000000000..99c12e801 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/FrugalQuantile.java @@ -0,0 +1,108 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.stat; + +import java.util.Random; + +/** + * Reference: Ma, Qiang, S. Muthukrishnan, and Mark Sandler. "Frugal Streaming for Estimating + * Quantiles." Space-Efficient Data Structures, Streams, and Algorithms. Springer Berlin Heidelberg, + * 2013. 77-96. + * + *

More info: http://blog.aggregateknowledge.com/2013/09/16/sketch-of-the-day-frugal-streaming/ + */ +@Deprecated +public class FrugalQuantile implements Quantile { + private final double increment; + private double quantile; + private Random rng; + + volatile double estimate; + int step; + int sign; + + public FrugalQuantile(double quantile, double increment, Random rng) { + this.increment = increment; + this.quantile = quantile; + this.estimate = 0.0; + this.step = 1; + this.sign = 0; + this.rng = rng; + } + + public FrugalQuantile(double quantile) { + this(quantile, 1.0, new Random()); + } + + public double estimation() { + return estimate; + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + return; + } + + if (x > estimate && rng.nextDouble() > (1 - quantile)) { + step += sign * increment; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + } else if (x < estimate && rng.nextDouble() > quantile) { + step -= sign * increment; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + } + } + + @Override + public String toString() { + return "FrugalQuantile(q=" + quantile + ", v=" + estimate + ")"; + } +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java new file mode 100644 index 000000000..00dd69de9 --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Median.java @@ -0,0 +1,79 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.stat; + +/** This implementation gives better results because it considers more data-point. */ +@Deprecated +public class Median extends FrugalQuantile { + public Median() { + super(0.5, 1.0, null); + } + + @Override + public synchronized void insert(double x) { + if (sign == 0) { + estimate = x; + sign = 1; + return; + } + + if (x > estimate) { + step += sign; + + if (step > 0) { + estimate += step; + } else { + estimate += 1; + } + + if (estimate > x) { + step += (x - estimate); + estimate = x; + } + + if (sign < 0) { + step = 1; + } + + sign = 1; + } else if (x < estimate) { + step -= sign; + + if (step > 0) { + estimate -= step; + } else { + estimate--; + } + + if (estimate < x) { + step += (estimate - x); + estimate = x; + } + + if (sign > 0) { + step = 1; + } + + sign = -1; + } + } + + @Override + public String toString() { + return "Median(v=" + estimate + ")"; + } +} diff --git a/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java new file mode 100644 index 000000000..aa3667e8f --- /dev/null +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/Quantile.java @@ -0,0 +1,29 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.stat; + +@Deprecated +public interface Quantile { + /** @return the estimation of the current value of the quantile */ + double estimation(); + + /** + * Insert a data point `x` in the quantile estimator. + * + * @param x the data point to add. + */ + void insert(double x); +} diff --git a/src/main/java/io/reactivesocket/exceptions/Retryable.java b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java similarity index 73% rename from src/main/java/io/reactivesocket/exceptions/Retryable.java rename to rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java index aaa2edde8..cfb071175 100644 --- a/src/main/java/io/reactivesocket/exceptions/Retryable.java +++ b/rsocket-load-balancer/src/main/java/io/rsocket/stat/package-info.java @@ -1,11 +1,11 @@ -/** - * Copyright 2015 Netflix, Inc. +/* + * Copyright 2015-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -13,9 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.reactivesocket.exceptions; -/** - * Marker interface only - */ -public interface Retryable {} +@NonNullApi +package io.rsocket.stat; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java new file mode 100644 index 000000000..52bf89558 --- /dev/null +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/LoadBalancedRSocketMonoTest.java @@ -0,0 +1,143 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.client.filter.RSocketSupplier; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class LoadBalancedRSocketMonoTest { + + @Test + @Timeout(10_000L) + public void testNeverSelectFailingFactories() throws InterruptedException { + TestingRSocket socket = new TestingRSocket(Function.identity()); + RSocketSupplier failing = failingClient(); + RSocketSupplier succeeding = succeedingFactory(socket); + List factories = Arrays.asList(failing, succeeding); + + testBalancer(factories); + } + + @Test + @Timeout(10_000L) + public void testNeverSelectFailingSocket() throws InterruptedException { + TestingRSocket socket = new TestingRSocket(Function.identity()); + TestingRSocket failingSocket = + new TestingRSocket(Function.identity()) { + @Override + public Mono requestResponse(Payload payload) { + return Mono.error(new RuntimeException("You shouldn't be here")); + } + + @Override + public double availability() { + return 0.0; + } + }; + + RSocketSupplier failing = succeedingFactory(failingSocket); + RSocketSupplier succeeding = succeedingFactory(socket); + List clients = Arrays.asList(failing, succeeding); + + testBalancer(clients); + } + + @Test + @Timeout(10_000L) + @Disabled + public void testRefreshesSocketsOnSelectBeforeReturningFailedAfterNewFactoriesDelivered() { + TestingRSocket socket = new TestingRSocket(Function.identity()); + + CompletableFuture laterSupplier = new CompletableFuture<>(); + Flux> factories = + Flux.create( + s -> { + s.next(Collections.emptyList()); + + laterSupplier.handle( + (RSocketSupplier result, Throwable t) -> { + s.next(Collections.singletonList(result)); + return null; + }); + }); + + LoadBalancedRSocketMono balancer = LoadBalancedRSocketMono.create(factories); + + assertThat(balancer.availability()).isZero(); + + laterSupplier.complete(succeedingFactory(socket)); + balancer.rSocketMono.block(); + + assertThat(balancer.availability()).isEqualTo(1.0); + } + + private void testBalancer(List factories) throws InterruptedException { + Publisher> src = + s -> { + s.onNext(factories); + s.onComplete(); + }; + + LoadBalancedRSocketMono balancer = LoadBalancedRSocketMono.create(src); + + while (balancer.availability() == 0.0) { + Thread.sleep(1); + } + + Flux.range(0, 100).flatMap(i -> balancer).blockLast(); + } + + private static RSocketSupplier succeedingFactory(RSocket socket) { + RSocketSupplier mock = Mockito.mock(RSocketSupplier.class); + + Mockito.when(mock.availability()).thenReturn(1.0); + Mockito.when(mock.get()).thenReturn(Mono.just(socket)); + Mockito.when(mock.onClose()).thenReturn(Mono.never()); + + return mock; + } + + private static RSocketSupplier failingClient() { + RSocketSupplier mock = Mockito.mock(RSocketSupplier.class); + + Mockito.when(mock.availability()).thenReturn(0.0); + Mockito.when(mock.get()) + .thenAnswer( + a -> { + fail(); + return null; + }); + + return mock; + } +} diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java new file mode 100644 index 000000000..9e1982465 --- /dev/null +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/RSocketSupplierTest.java @@ -0,0 +1,148 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.verify; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.client.filter.RSocketSupplier; +import io.rsocket.test.TestSubscriber; +import io.rsocket.util.EmptyPayload; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiConsumer; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.publisher.Mono; + +public class RSocketSupplierTest { + + @Test + public void testError() throws InterruptedException { + testRSocket( + (latch, socket) -> { + assertThat(socket.availability()).isEqualTo(1.0); + Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); + + Subscriber subscriber = TestSubscriber.create(); + payloadPublisher.subscribe(subscriber); + + verify(subscriber).onComplete(); + + double good = socket.availability(); + + try { + Thread.sleep(100); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + subscriber = TestSubscriber.create(); + payloadPublisher.subscribe(subscriber); + verify(subscriber).onError(any(RuntimeException.class)); + double bad = socket.availability(); + assertThat(good > bad).isTrue(); + latch.countDown(); + }); + } + + @Test + public void testWidowReset() throws InterruptedException { + testRSocket( + (latch, socket) -> { + assertThat(socket.availability()).isEqualTo(1.0); + Publisher payloadPublisher = socket.requestResponse(EmptyPayload.INSTANCE); + + Subscriber subscriber = TestSubscriber.create(); + payloadPublisher.subscribe(subscriber); + + verify(subscriber).onComplete(); + double good = socket.availability(); + + subscriber = TestSubscriber.create(); + payloadPublisher.subscribe(subscriber); + + verify(subscriber).onError(any(RuntimeException.class)); + double bad = socket.availability(); + assertThat(good > bad).isTrue(); + + try { + Thread.sleep(200); + } catch (InterruptedException e) { + e.printStackTrace(); + } + + double reset = socket.availability(); + assertThat(reset > bad).isTrue(); + latch.countDown(); + }); + } + + private void testRSocket(BiConsumer f) throws InterruptedException { + AtomicInteger count = new AtomicInteger(0); + TestingRSocket socket = + new TestingRSocket( + input -> { + if (count.getAndIncrement() < 1) { + return EmptyPayload.INSTANCE; + } else { + throw new RuntimeException(); + } + }); + + RSocketSupplier factory = Mockito.mock(RSocketSupplier.class); + + Mockito.when(factory.availability()).thenReturn(1.0); + Mockito.when(factory.get()).thenReturn(Mono.just(socket)); + + RSocketSupplier failureFactory = new RSocketSupplier(factory, 100, TimeUnit.MILLISECONDS); + + CountDownLatch latch = new CountDownLatch(1); + failureFactory + .get() + .subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(1); + } + + @Override + public void onNext(RSocket socket) { + f.accept(latch, socket); + } + + @Override + public void onError(Throwable t) { + fail(); + } + + @Override + public void onComplete() {} + }); + + latch.await(30, TimeUnit.SECONDS); + } +} diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java new file mode 100644 index 000000000..2827c8ed4 --- /dev/null +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TestingRSocket.java @@ -0,0 +1,145 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.BiFunction; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import reactor.core.Scannable; +import reactor.core.publisher.*; + +public class TestingRSocket implements RSocket { + + private final AtomicInteger count; + private final Sinks.Empty onClose = Sinks.empty(); + private final BiFunction, Payload, Boolean> eachPayloadHandler; + + public TestingRSocket(Function responder) { + this( + (subscriber, payload) -> { + subscriber.onNext(responder.apply(payload)); + return true; + }); + } + + public TestingRSocket( + BiFunction, Payload, Boolean> eachPayloadHandler) { + this.eachPayloadHandler = eachPayloadHandler; + this.count = new AtomicInteger(0); + } + + public int countMessageReceived() { + return count.get(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return Mono.empty(); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.from( + subscriber -> + subscriber.onSubscribe( + new Subscription() { + boolean cancelled; + + @Override + public void request(long n) { + if (cancelled) { + return; + } + try { + count.incrementAndGet(); + if (eachPayloadHandler.apply(subscriber, payload)) { + subscriber.onComplete(); + } + } catch (Throwable t) { + subscriber.onError(t); + } + } + + @Override + public void cancel() {} + })); + } + + @Override + public Flux requestStream(Payload payload) { + return requestResponse(payload).flux(); + } + + @Override + public Flux requestChannel(Publisher inputs) { + return Flux.from( + subscriber -> + inputs.subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + subscriber.onSubscribe(s); + } + + @Override + public void onNext(Payload input) { + eachPayloadHandler.apply(subscriber, input); + } + + @Override + public void onError(Throwable t) { + subscriber.onError(t); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } + })); + } + + @Override + public Mono metadataPush(Payload payload) { + return fireAndForget(payload); + } + + @Override + public double availability() { + return 1.0; + } + + @Override + public void dispose() { + onClose.tryEmitEmpty(); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); + } + + @Override + public Mono onClose() { + return onClose.asMono(); + } +} diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java new file mode 100644 index 000000000..b8866b1f6 --- /dev/null +++ b/rsocket-load-balancer/src/test/java/io/rsocket/client/TimeoutClientTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.client; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.client.filter.RSockets; +import io.rsocket.util.EmptyPayload; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +public class TimeoutClientTest { + @Test + public void testTimeoutSocket() { + TestingRSocket socket = new TestingRSocket((subscriber, payload) -> false); + RSocket timeout = RSockets.timeout(Duration.ofMillis(50)).apply(socket); + + timeout + .requestResponse(EmptyPayload.INSTANCE) + .subscribe( + new Subscriber() { + @Override + public void onSubscribe(Subscription s) { + s.request(1); + } + + @Override + public void onNext(Payload payload) { + throw new AssertionError("onNext invoked when not expected."); + } + + @Override + public void onError(Throwable t) { + assertThat(t) + .describedAs("Unexpected exception in onError") + .isInstanceOf(TimeoutException.class); + } + + @Override + public void onComplete() { + throw new AssertionError("onComplete invoked when not expected."); + } + }); + } +} diff --git a/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java new file mode 100644 index 000000000..b214a725e --- /dev/null +++ b/rsocket-load-balancer/src/test/java/io/rsocket/stat/MedianTest.java @@ -0,0 +1,66 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.stat; + +import java.util.Arrays; +import java.util.Random; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.Test; + +public class MedianTest { + private double errorSum = 0; + private double maxError = 0; + private double minError = 0; + + @Test + public void testMedian() { + Random rng = new Random("Repeatable tests".hashCode()); + int n = 1; + for (int i = 0; i < n; i++) { + testMedian(rng); + } + System.out.println( + "Error avg = " + (errorSum / n) + " in range [" + minError + ", " + maxError + "]"); + } + + /** Test Median estimation with normal random data */ + private void testMedian(Random rng) { + int n = 100 * 1024; + int range = Integer.MAX_VALUE >> 16; + Median m = new Median(); + + int[] data = new int[n]; + for (int i = 0; i < data.length; i++) { + int x = Math.max(0, range / 2 + (int) (range / 5 * rng.nextGaussian())); + data[i] = x; + m.insert(x); + } + Arrays.sort(data); + + int expected = data[data.length / 2]; + double estimation = m.estimation(); + double error = Math.abs(expected - estimation) / expected; + + errorSum += error; + maxError = Math.max(maxError, error); + minError = Math.min(minError, error); + + Assertions.assertThat(error < 0.02) + .describedAs("p50=" + estimation + ", real=" + expected + ", error=" + error) + .isTrue(); + } +} diff --git a/rsocket-load-balancer/src/test/resources/logback-test.xml b/rsocket-load-balancer/src/test/resources/logback-test.xml new file mode 100644 index 000000000..13e65b37d --- /dev/null +++ b/rsocket-load-balancer/src/test/resources/logback-test.xml @@ -0,0 +1,33 @@ + + + + + + + + %d{dd MMM yyyy HH:mm:ss,SSS} %5p [%t] %c{1} - %m%n + + + + + + + + + + + diff --git a/rsocket-micrometer/build.gradle b/rsocket-micrometer/build.gradle new file mode 100644 index 000000000..debf02f34 --- /dev/null +++ b/rsocket-micrometer/build.gradle @@ -0,0 +1,47 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + +dependencies { + api project(':rsocket-core') + api 'io.micrometer:micrometer-observation' + api 'io.micrometer:micrometer-core' + api 'io.micrometer:micrometer-tracing' + + implementation 'org.slf4j:slf4j-api' + + testImplementation project(':rsocket-test') + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.mockito:mockito-core' + + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.micrometer") + } +} + +description = 'Transparent Metrics exposure to Micrometer' diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java new file mode 100644 index 000000000..7c7ac37b9 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnection.java @@ -0,0 +1,267 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static io.rsocket.frame.FrameType.*; + +import io.micrometer.core.instrument.*; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.FrameHeaderCodec; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import java.net.SocketAddress; +import java.util.Objects; +import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +/** + * An implementation of {@link DuplexConnection} that intercepts frames and gathers Micrometer + * metrics about them. + * + *

The metric is called {@code rsocket.frame} and is tagged with {@code connection.type} ({@link + * Type}), {@code frame.type} ({@link FrameType}), and any additional configured tags. {@code + * rsocket.duplex.connection.close} and {@code rsocket.duplex.connection.dispose} metrics, tagged + * with {@code connection.type} ({@link Type}) and any additional configured tags are also + * collected. + * + * @see Micrometer + */ +final class MicrometerDuplexConnection implements DuplexConnection { + + private final Counter close; + + private final DuplexConnection delegate; + + private final Counter dispose; + + private final FrameCounters frameCounters; + + /** + * Creates a new {@link DuplexConnection}. + * + * @param connectionType the type of connection being monitored + * @param delegate the {@link DuplexConnection} to delegate to + * @param meterRegistry the {@link MeterRegistry} to use + * @param tags additional tags to attach to {@link Meter}s + * @throws NullPointerException if {@code connectionType}, {@code delegate}, or {@code + * meterRegistry} is {@code null} + */ + MicrometerDuplexConnection( + Type connectionType, DuplexConnection delegate, MeterRegistry meterRegistry, Tag... tags) { + + Objects.requireNonNull(connectionType, "connectionType must not be null"); + this.delegate = Objects.requireNonNull(delegate, "delegate must not be null"); + Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + + this.close = + meterRegistry.counter( + "rsocket.duplex.connection.close", + Tags.of(tags).and("connection.type", connectionType.name())); + this.dispose = + meterRegistry.counter( + "rsocket.duplex.connection.dispose", + Tags.of(tags).and("connection.type", connectionType.name())); + this.frameCounters = new FrameCounters(connectionType, meterRegistry, tags); + } + + @Override + public ByteBufAllocator alloc() { + return delegate.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return delegate.remoteAddress(); + } + + @Override + public void dispose() { + delegate.dispose(); + dispose.increment(); + } + + @Override + public Mono onClose() { + return delegate.onClose().doAfterTerminate(close::increment); + } + + @Override + public Flux receive() { + return delegate.receive().doOnNext(frameCounters); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + frameCounters.accept(frame); + delegate.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + delegate.sendErrorAndClose(e); + } + + private static final class FrameCounters implements Consumer { + + private final Logger logger = LoggerFactory.getLogger(this.getClass()); + + private final Counter cancel; + + private final Counter complete; + + private final Counter error; + + private final Counter extension; + + private final Counter keepalive; + + private final Counter lease; + + private final Counter metadataPush; + + private final Counter next; + + private final Counter nextComplete; + + private final Counter payload; + + private final Counter requestChannel; + + private final Counter requestFireAndForget; + + private final Counter requestN; + + private final Counter requestResponse; + + private final Counter requestStream; + + private final Counter resume; + + private final Counter resumeOk; + + private final Counter setup; + + private final Counter unknown; + + private FrameCounters(Type connectionType, MeterRegistry meterRegistry, Tag... tags) { + this.cancel = counter(connectionType, meterRegistry, CANCEL, tags); + this.complete = counter(connectionType, meterRegistry, COMPLETE, tags); + this.error = counter(connectionType, meterRegistry, ERROR, tags); + this.extension = counter(connectionType, meterRegistry, EXT, tags); + this.keepalive = counter(connectionType, meterRegistry, KEEPALIVE, tags); + this.lease = counter(connectionType, meterRegistry, LEASE, tags); + this.metadataPush = counter(connectionType, meterRegistry, METADATA_PUSH, tags); + this.next = counter(connectionType, meterRegistry, NEXT, tags); + this.nextComplete = counter(connectionType, meterRegistry, NEXT_COMPLETE, tags); + this.payload = counter(connectionType, meterRegistry, PAYLOAD, tags); + this.requestChannel = counter(connectionType, meterRegistry, REQUEST_CHANNEL, tags); + this.requestFireAndForget = counter(connectionType, meterRegistry, REQUEST_FNF, tags); + this.requestN = counter(connectionType, meterRegistry, REQUEST_N, tags); + this.requestResponse = counter(connectionType, meterRegistry, REQUEST_RESPONSE, tags); + this.requestStream = counter(connectionType, meterRegistry, REQUEST_STREAM, tags); + this.resume = counter(connectionType, meterRegistry, RESUME, tags); + this.resumeOk = counter(connectionType, meterRegistry, RESUME_OK, tags); + this.setup = counter(connectionType, meterRegistry, SETUP, tags); + this.unknown = counter(connectionType, meterRegistry, "UNKNOWN", tags); + } + + private static Counter counter( + Type connectionType, MeterRegistry meterRegistry, FrameType frameType, Tag... tags) { + + return counter(connectionType, meterRegistry, frameType.name(), tags); + } + + private static Counter counter( + Type connectionType, MeterRegistry meterRegistry, String frameType, Tag... tags) { + + return meterRegistry.counter( + "rsocket.frame", + Tags.of(tags).and("connection.type", connectionType.name()).and("frame.type", frameType)); + } + + @Override + public void accept(ByteBuf frame) { + FrameType frameType = FrameHeaderCodec.frameType(frame); + + switch (frameType) { + case SETUP: + this.setup.increment(); + break; + case LEASE: + this.lease.increment(); + break; + case KEEPALIVE: + this.keepalive.increment(); + break; + case REQUEST_RESPONSE: + this.requestResponse.increment(); + break; + case REQUEST_FNF: + this.requestFireAndForget.increment(); + break; + case REQUEST_STREAM: + this.requestStream.increment(); + break; + case REQUEST_CHANNEL: + this.requestChannel.increment(); + break; + case REQUEST_N: + this.requestN.increment(); + break; + case CANCEL: + this.cancel.increment(); + break; + case PAYLOAD: + this.payload.increment(); + break; + case ERROR: + this.error.increment(); + break; + case METADATA_PUSH: + this.metadataPush.increment(); + break; + case RESUME: + this.resume.increment(); + break; + case RESUME_OK: + this.resumeOk.increment(); + break; + case NEXT: + this.next.increment(); + break; + case COMPLETE: + this.complete.increment(); + break; + case NEXT_COMPLETE: + this.nextComplete.increment(); + break; + case EXT: + this.extension.increment(); + break; + default: + this.logger.debug("Skipping count of unknown frame type: {}", frameType); + this.unknown.increment(); + } + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptor.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptor.java new file mode 100644 index 000000000..b94e969ec --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptor.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import java.util.Objects; + +/** + * An implementation of {@link DuplexConnectionInterceptor} that intercepts frames and gathers + * Micrometer metrics about them. + * + *

The metric is called {@code rsocket.frame} and is tagged with {@code connection.type} ({@link + * Type}), {@code frame.type} ({@link FrameType}), and any additional configured tags. {@code + * rsocket.duplex.connection.close} and {@code rsocket.duplex.connection.dispose} metrics, tagged + * with {@code connection.type} ({@link Type}) and any additional configured tags are also + * collected. + * + * @see Micrometer + */ +public final class MicrometerDuplexConnectionInterceptor implements DuplexConnectionInterceptor { + + private final MeterRegistry meterRegistry; + + private final Tag[] tags; + + /** + * Creates a new {@link DuplexConnectionInterceptor}. + * + * @param meterRegistry the {@link MeterRegistry} to use to create {@link Meter}s. + * @param tags the additional tags to attach to each {@link Meter} + * @throws NullPointerException if {@code meterRegistry} is {@code null} + */ + public MicrometerDuplexConnectionInterceptor(MeterRegistry meterRegistry, Tag... tags) { + this.meterRegistry = Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + this.tags = tags; + } + + @Override + public MicrometerDuplexConnection apply(Type connectionType, DuplexConnection delegate) { + Objects.requireNonNull(connectionType, "connectionType must not be null"); + Objects.requireNonNull(delegate, "delegate must not be null"); + + return new MicrometerDuplexConnection(connectionType, delegate, meterRegistry, tags); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocket.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocket.java new file mode 100644 index 000000000..9e1abbc03 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocket.java @@ -0,0 +1,206 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static reactor.core.publisher.SignalType.CANCEL; +import static reactor.core.publisher.SignalType.ON_COMPLETE; +import static reactor.core.publisher.SignalType.ON_ERROR; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Tags; +import io.micrometer.core.instrument.Timer; +import io.micrometer.core.instrument.Timer.Sample; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; + +/** + * An implementation of {@link RSocket} that intercepts interactions and gathers Micrometer metrics + * about them. + * + *

The metrics are called {@code rsocket.[ metadata.push | request.channel | request.fnf | + * request.response | request.stream ]} and is tagged with {@code signal.type} ({@link SignalType}) + * and any additional configured tags. + * + * @see Micrometer + */ +final class MicrometerRSocket implements RSocket { + + private final RSocket delegate; + + private final InteractionCounters metadataPush; + + private final InteractionCounters requestChannel; + + private final InteractionCounters requestFireAndForget; + + private final InteractionTimers requestResponse; + + private final InteractionCounters requestStream; + + /** + * Creates a new {@link RSocket}. + * + * @param delegate the {@link RSocket} to delegate to + * @param meterRegistry the {@link MeterRegistry} to use + * @param tags additional tags to attach to {@link Meter}s + * @throws NullPointerException if {@code delegate} or {@code meterRegistry} is {@code null} + */ + MicrometerRSocket(RSocket delegate, MeterRegistry meterRegistry, Tag... tags) { + this.delegate = Objects.requireNonNull(delegate, "delegate must not be null"); + Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + + this.metadataPush = new InteractionCounters(meterRegistry, "metadata.push", tags); + this.requestChannel = new InteractionCounters(meterRegistry, "request.channel", tags); + this.requestFireAndForget = new InteractionCounters(meterRegistry, "request.fnf", tags); + this.requestResponse = new InteractionTimers(meterRegistry, "request.response", tags); + this.requestStream = new InteractionCounters(meterRegistry, "request.stream", tags); + } + + @Override + public void dispose() { + delegate.dispose(); + } + + @Override + public Mono fireAndForget(Payload payload) { + return delegate.fireAndForget(payload).doFinally(requestFireAndForget); + } + + @Override + public Mono metadataPush(Payload payload) { + return delegate.metadataPush(payload).doFinally(metadataPush); + } + + @Override + public Mono onClose() { + return delegate.onClose(); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return delegate.requestChannel(payloads).doFinally(requestChannel); + } + + @Override + public Mono requestResponse(Payload payload) { + return Mono.defer( + () -> { + Sample sample = requestResponse.start(); + + return delegate + .requestResponse(payload) + .doFinally(signalType -> requestResponse.accept(sample, signalType)); + }); + } + + @Override + public Flux requestStream(Payload payload) { + return delegate.requestStream(payload).doFinally(requestStream); + } + + private static final class InteractionCounters implements Consumer { + + private final Counter cancel; + + private final Counter onComplete; + + private final Counter onError; + + private InteractionCounters(MeterRegistry meterRegistry, String interactionModel, Tag... tags) { + this.cancel = counter(meterRegistry, interactionModel, CANCEL, tags); + this.onComplete = counter(meterRegistry, interactionModel, ON_COMPLETE, tags); + this.onError = counter(meterRegistry, interactionModel, ON_ERROR, tags); + } + + @Override + public void accept(SignalType signalType) { + switch (signalType) { + case CANCEL: + cancel.increment(); + break; + case ON_COMPLETE: + onComplete.increment(); + break; + case ON_ERROR: + onError.increment(); + break; + } + } + + private static Counter counter( + MeterRegistry meterRegistry, String interactionModel, SignalType signalType, Tag... tags) { + + return meterRegistry.counter( + "rsocket." + interactionModel, Tags.of(tags).and("signal.type", signalType.name())); + } + } + + private static final class InteractionTimers implements BiConsumer { + + private final Timer cancel; + + private final MeterRegistry meterRegistry; + + private final Timer onComplete; + + private final Timer onError; + + private InteractionTimers(MeterRegistry meterRegistry, String interactionModel, Tag... tags) { + this.meterRegistry = meterRegistry; + + this.cancel = timer(meterRegistry, interactionModel, CANCEL, tags); + this.onComplete = timer(meterRegistry, interactionModel, ON_COMPLETE, tags); + this.onError = timer(meterRegistry, interactionModel, ON_ERROR, tags); + } + + @Override + public void accept(Sample sample, SignalType signalType) { + switch (signalType) { + case CANCEL: + sample.stop(cancel); + break; + case ON_COMPLETE: + sample.stop(onComplete); + break; + case ON_ERROR: + sample.stop(onError); + break; + } + } + + Sample start() { + return Timer.start(meterRegistry); + } + + private static Timer timer( + MeterRegistry meterRegistry, String interactionModel, SignalType signalType, Tag... tags) { + + return meterRegistry.timer( + "rsocket." + interactionModel, Tags.of(tags).and("signal.type", signalType.name())); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocketInterceptor.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocketInterceptor.java new file mode 100644 index 000000000..c405c8601 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/MicrometerRSocketInterceptor.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import io.micrometer.core.instrument.Meter; +import io.micrometer.core.instrument.MeterRegistry; +import io.micrometer.core.instrument.Tag; +import io.rsocket.RSocket; +import io.rsocket.plugins.RSocketInterceptor; +import java.util.Objects; +import reactor.core.publisher.SignalType; + +/** + * An implementation of {@link RSocketInterceptor} that intercepts interactions and gathers + * Micrometer metrics about them. + * + *

The metrics are called {@code rsocket.[ metadata.push | request.channel | request.fnf | + * request.response | request.stream ]} and is tagged with {@code signal.type} ({@link SignalType}) + * and any additional configured tags. + * + * @see Micrometer + */ +public final class MicrometerRSocketInterceptor implements RSocketInterceptor { + + private final MeterRegistry meterRegistry; + + private final Tag[] tags; + + /** + * Creates a new {@link RSocketInterceptor}. + * + * @param meterRegistry the {@link MeterRegistry} to use to create {@link Meter}s. + * @param tags the additional tags to attach to each {@link Meter} + * @throws NullPointerException if {@code meterRegistry} is {@code null} + */ + public MicrometerRSocketInterceptor(MeterRegistry meterRegistry, Tag... tags) { + this.meterRegistry = Objects.requireNonNull(meterRegistry, "meterRegistry must not be null"); + this.tags = tags; + } + + @Override + public MicrometerRSocket apply(RSocket delegate) { + Objects.requireNonNull(delegate, "delegate must not be null"); + + return new MicrometerRSocket(delegate, meterRegistry, tags); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java new file mode 100644 index 000000000..09c8ba316 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufGetter.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.rsocket.metadata.CompositeMetadata; + +public class ByteBufGetter implements Propagator.Getter { + + @Override + public String get(ByteBuf carrier, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(carrier, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + if (key.equals(entry.getMimeType())) { + return entry.getContent().toString(CharsetUtil.UTF_8); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java new file mode 100644 index 000000000..678bdb1ed --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ByteBufSetter.java @@ -0,0 +1,33 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.metadata.CompositeMetadataCodec; + +public class ByteBufSetter implements Propagator.Setter { + + @Override + public void set(CompositeByteBuf carrier, String key, String value) { + final ByteBufAllocator alloc = carrier.alloc(); + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + carrier, alloc, key, ByteBufUtil.writeUtf8(alloc, value)); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java new file mode 100644 index 000000000..357be8f15 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/CompositeMetadataUtils.java @@ -0,0 +1,40 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.core.lang.Nullable; +import io.netty.buffer.ByteBuf; +import io.rsocket.metadata.CompositeMetadata; + +final class CompositeMetadataUtils { + + private CompositeMetadataUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + @Nullable + static ByteBuf extract(ByteBuf metadata, String key) { + final CompositeMetadata compositeMetadata = new CompositeMetadata(metadata, false); + for (CompositeMetadata.Entry entry : compositeMetadata) { + final String entryKey = entry.getMimeType(); + if (key.equals(entryKey)) { + return entry.getContent(); + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java new file mode 100644 index 000000000..2c10fc78d --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketObservationConvention.java @@ -0,0 +1,49 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +class DefaultRSocketObservationConvention { + + private final RSocketContext rSocketContext; + + public DefaultRSocketObservationConvention(RSocketContext rSocketContext) { + this.rSocketContext = rSocketContext; + } + + String getName() { + if (this.rSocketContext.frameType == FrameType.REQUEST_FNF) { + return "rsocket.fnf"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_STREAM) { + return "rsocket.stream"; + } else if (this.rSocketContext.frameType == FrameType.REQUEST_CHANNEL) { + return "rsocket.channel"; + } + return "%s"; + } + + protected RSocketContext getRSocketContext() { + return this.rSocketContext; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java new file mode 100644 index 000000000..73e04b749 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketRequesterObservationConvention.java @@ -0,0 +1,62 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketRequesterObservationConvention + extends DefaultRSocketObservationConvention implements RSocketRequesterObservationConvention { + + public DefaultRSocketRequesterObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues values = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + values = + values.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return values; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.request"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java new file mode 100644 index 000000000..5318c1b37 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/DefaultRSocketResponderObservationConvention.java @@ -0,0 +1,61 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.KeyValues; +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.rsocket.frame.FrameType; + +/** + * Default {@link RSocketRequesterObservationConvention} implementation. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public class DefaultRSocketResponderObservationConvention + extends DefaultRSocketObservationConvention implements RSocketResponderObservationConvention { + + public DefaultRSocketResponderObservationConvention(RSocketContext rSocketContext) { + super(rSocketContext); + } + + @Override + public KeyValues getLowCardinalityKeyValues(RSocketContext context) { + KeyValues tags = + KeyValues.of( + RSocketObservationDocumentation.ResponderTags.REQUEST_TYPE.withValue( + context.frameType.name())); + if (StringUtils.isNotBlank(context.route)) { + tags = tags.and(RSocketObservationDocumentation.ResponderTags.ROUTE.withValue(context.route)); + } + return tags; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext; + } + + @Override + public String getName() { + if (getRSocketContext().frameType == FrameType.REQUEST_RESPONSE) { + return "rsocket.response"; + } + return super.getName(); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java new file mode 100644 index 000000000..fb80ea317 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationRequesterRSocketProxy.java @@ -0,0 +1,208 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.docs.ObservationDocumentation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import java.util.function.Function; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; +import reactor.util.context.ContextView; + +/** + * Tracing representation of a {@link RSocketProxy} for the requester. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationRequesterRSocketProxy extends RSocketProxy { + + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketRequesterObservationConvention observationConvention; + + public ObservationRequesterRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationRequesterRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketRequesterObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + return setObservation( + super::fireAndForget, + payload, + FrameType.REQUEST_FNF, + RSocketObservationDocumentation.RSOCKET_REQUESTER_FNF); + } + + @Override + public Mono requestResponse(Payload payload) { + return setObservation( + super::requestResponse, + payload, + FrameType.REQUEST_RESPONSE, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_RESPONSE); + } + + Mono setObservation( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation observation) { + return Mono.deferContextual( + contextView -> observe(input, payload, frameType, observation, contextView)); + } + + private String route(Payload payload) { + if (payload.hasMetadata()) { + try { + ByteBuf extracted = + CompositeMetadataUtils.extract( + payload.sliceMetadata(), WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + final RoutingMetadata routingMetadata = new RoutingMetadata(extracted); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } catch (Exception e) { + + } + } + return null; + } + + private Mono observe( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs, + ContextView contextView) { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, payload.sliceMetadata(), frameType, route, RSocketContext.Side.REQUESTER); + Observation parentObservation = contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation observation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, observation); + observation.start(); + Payload newPayload = payload; + if (rSocketContext.modifiedPayload != null) { + newPayload = rSocketContext.modifiedPayload; + } + return input + .apply(newPayload) + .doOnError(observation::error) + .doFinally(signalType -> observation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, observation)); + } + + @Override + public Flux requestStream(Payload payload) { + return observationFlux( + super::requestStream, + payload, + FrameType.REQUEST_STREAM, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_STREAM); + } + + @Override + public Flux requestChannel(Publisher inbound) { + return Flux.from(inbound) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + return observationFlux( + p -> super.requestChannel(flux.skip(1).startWith(p)), + firstPayload, + FrameType.REQUEST_CHANNEL, + RSocketObservationDocumentation.RSOCKET_REQUESTER_REQUEST_CHANNEL); + } + return flux; + }); + } + + private Flux observationFlux( + Function> input, + Payload payload, + FrameType frameType, + ObservationDocumentation obs) { + return Flux.deferContextual( + contextView -> { + String route = route(payload); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + frameType, + route, + RSocketContext.Side.REQUESTER); + Observation parentObservation = + contextView.getOrDefault(MICROMETER_OBSERVATION_KEY, null); + Observation newObservation = + obs.observation( + this.observationConvention, + new DefaultRSocketRequesterObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry) + .parentObservation(parentObservation); + setContextualName(frameType, route, newObservation); + newObservation.start(); + return input + .apply(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + }); + } + + private void setContextualName(FrameType frameType, String route, Observation newObservation) { + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(frameType.name() + " " + route); + } else { + newObservation.contextualName(frameType.name()); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java new file mode 100644 index 000000000..9ed27adf3 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/ObservationResponderRSocketProxy.java @@ -0,0 +1,179 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.util.StringUtils; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.RSocketProxy; +import java.util.Iterator; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +/** + * Tracing representation of a {@link RSocketProxy} for the responder. + * + * @author Marcin Grzejszczak + * @author Oleh Dokuka + * @since 1.1.4 + */ +public class ObservationResponderRSocketProxy extends RSocketProxy { + /** Aligned with ObservationThreadLocalAccessor#KEY */ + private static final String MICROMETER_OBSERVATION_KEY = "micrometer.observation"; + + private final ObservationRegistry observationRegistry; + + @Nullable private final RSocketResponderObservationConvention observationConvention; + + public ObservationResponderRSocketProxy(RSocket source, ObservationRegistry observationRegistry) { + this(source, observationRegistry, null); + } + + public ObservationResponderRSocketProxy( + RSocket source, + ObservationRegistry observationRegistry, + RSocketResponderObservationConvention observationConvention) { + super(source); + this.observationRegistry = observationRegistry; + this.observationConvention = observationConvention; + } + + @Override + public Mono fireAndForget(Payload payload) { + // called on Netty EventLoop + // there can't be observation in thread local here + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_FNF, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation(RSocketObservationDocumentation.RSOCKET_RESPONDER_FNF, rSocketContext); + return super.fireAndForget(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + private Observation startObservation( + RSocketObservationDocumentation observation, RSocketContext rSocketContext) { + return observation.start( + this.observationConvention, + new DefaultRSocketResponderObservationConvention(rSocketContext), + () -> rSocketContext, + this.observationRegistry); + } + + @Override + public Mono requestResponse(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, + payload.sliceMetadata(), + FrameType.REQUEST_RESPONSE, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_RESPONSE, rSocketContext); + return super.requestResponse(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestStream(Payload payload) { + ByteBuf sliceMetadata = payload.sliceMetadata(); + String route = route(payload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + payload, sliceMetadata, FrameType.REQUEST_STREAM, route, RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_STREAM, rSocketContext); + return super.requestStream(rSocketContext.modifiedPayload) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite(context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + + @Override + public Flux requestChannel(Publisher payloads) { + return Flux.from(payloads) + .switchOnFirst( + (firstSignal, flux) -> { + final Payload firstPayload = firstSignal.get(); + if (firstPayload != null) { + ByteBuf sliceMetadata = firstPayload.sliceMetadata(); + String route = route(firstPayload, sliceMetadata); + RSocketContext rSocketContext = + new RSocketContext( + firstPayload, + firstPayload.sliceMetadata(), + FrameType.REQUEST_CHANNEL, + route, + RSocketContext.Side.RESPONDER); + Observation newObservation = + startObservation( + RSocketObservationDocumentation.RSOCKET_RESPONDER_REQUEST_CHANNEL, + rSocketContext); + if (StringUtils.isNotBlank(route)) { + newObservation.contextualName(rSocketContext.frameType.name() + " " + route); + } + return super.requestChannel(flux.skip(1).startWith(rSocketContext.modifiedPayload)) + .doOnError(newObservation::error) + .doFinally(signalType -> newObservation.stop()) + .contextWrite( + context -> context.put(MICROMETER_OBSERVATION_KEY, newObservation)); + } + return flux; + }); + } + + private String route(Payload payload, ByteBuf headers) { + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + return iterator.next(); + } + } catch (Exception e) { + + } + } + return null; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java new file mode 100644 index 000000000..e5286a53f --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/PayloadUtils.java @@ -0,0 +1,73 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.CompositeMetadata; +import io.rsocket.metadata.CompositeMetadata.Entry; +import io.rsocket.metadata.CompositeMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import io.rsocket.util.ByteBufPayload; +import io.rsocket.util.DefaultPayload; +import java.util.HashSet; +import java.util.Set; + +final class PayloadUtils { + + private PayloadUtils() { + throw new IllegalStateException("Can't instantiate a utility class"); + } + + static CompositeByteBuf cleanTracingMetadata(Payload payload, Set fields) { + Set fieldsWithDefaultZipkin = new HashSet<>(fields); + fieldsWithDefaultZipkin.add(WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + final CompositeByteBuf metadata = ByteBufAllocator.DEFAULT.compositeBuffer(); + if (payload.hasMetadata()) { + try { + final CompositeMetadata entries = new CompositeMetadata(payload.metadata(), false); + for (Entry entry : entries) { + if (!fieldsWithDefaultZipkin.contains(entry.getMimeType())) { + CompositeMetadataCodec.encodeAndAddMetadataWithCompression( + metadata, + ByteBufAllocator.DEFAULT, + entry.getMimeType(), + entry.getContent().retain()); + } + } + } catch (Exception e) { + + } + } + return metadata; + } + + static Payload payload(Payload payload, CompositeByteBuf metadata) { + final Payload newPayload; + try { + if (payload instanceof ByteBufPayload) { + newPayload = ByteBufPayload.create(payload.data().retain(), metadata); + } else { + newPayload = DefaultPayload.create(payload.data().retain(), metadata); + } + } finally { + payload.release(); + } + return newPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java new file mode 100644 index 000000000..8622cdfa5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketContext.java @@ -0,0 +1,76 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.lang.Nullable; +import io.micrometer.observation.Observation; +import io.netty.buffer.ByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; + +public class RSocketContext extends Observation.Context { + + final Payload payload; + + final ByteBuf metadata; + + final FrameType frameType; + + final String route; + + final Side side; + + Payload modifiedPayload; + + RSocketContext( + Payload payload, ByteBuf metadata, FrameType frameType, @Nullable String route, Side side) { + this.payload = payload; + this.metadata = metadata; + this.frameType = frameType; + this.route = route; + this.side = side; + } + + public enum Side { + REQUESTER, + RESPONDER + } + + public Payload getPayload() { + return payload; + } + + public ByteBuf getMetadata() { + return metadata; + } + + public FrameType getFrameType() { + return frameType; + } + + public String getRoute() { + return route; + } + + public Side getSide() { + return side; + } + + public Payload getModifiedPayload() { + return modifiedPayload; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java new file mode 100644 index 000000000..1be6b4599 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketObservationDocumentation.java @@ -0,0 +1,232 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.common.docs.KeyName; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; +import io.micrometer.observation.docs.ObservationDocumentation; + +enum RSocketObservationDocumentation implements ObservationDocumentation { + + /** Observation created on the RSocket responder side. */ + RSOCKET_RESPONDER { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + }, + + /** Observation created on the RSocket requester side for Fire and Forget frame type. */ + RSOCKET_REQUESTER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Fire and Forget frame type. */ + RSOCKET_RESPONDER_FNF { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Response frame type. */ + RSOCKET_REQUESTER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Response frame type. */ + RSOCKET_RESPONDER_REQUEST_RESPONSE { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Stream frame type. */ + RSOCKET_REQUESTER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Stream frame type. */ + RSOCKET_RESPONDER_REQUEST_STREAM { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket requester side for Request Channel frame type. */ + RSOCKET_REQUESTER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketRequesterObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return RequesterTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }, + + /** Observation created on the RSocket responder side for Request Channel frame type. */ + RSOCKET_RESPONDER_REQUEST_CHANNEL { + @Override + public Class> + getDefaultConvention() { + return DefaultRSocketResponderObservationConvention.class; + } + + @Override + public KeyName[] getLowCardinalityKeyNames() { + return ResponderTags.values(); + } + + @Override + public String getPrefix() { + return "rsocket."; + } + }; + + enum RequesterTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + }, + + /** Name of the RSocket content type. */ + CONTENT_TYPE { + @Override + public String asString() { + return "rsocket.content-type"; + } + } + } + + enum ResponderTags implements KeyName { + + /** Name of the RSocket route. */ + ROUTE { + @Override + public String asString() { + return "rsocket.route"; + } + }, + + /** Name of the RSocket request type. */ + REQUEST_TYPE { + @Override + public String asString() { + return "rsocket.request-type"; + } + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java new file mode 100644 index 000000000..d795f81b5 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket requester {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketRequesterObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java new file mode 100644 index 000000000..996267d4a --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketRequesterTracingObservationHandler.java @@ -0,0 +1,131 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.metadata.TracingMetadataCodec; +import java.util.HashSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketRequesterTracingObservationHandler + implements TracingObservationHandler { + private static final Logger log = + LoggerFactory.getLogger(RSocketRequesterTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Setter setter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketRequesterTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Setter setter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.setter = setter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.REQUESTER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + @Override + public void onStart(RSocketContext context) { + Payload payload = context.payload; + Span.Builder spanBuilder = this.tracer.spanBuilder(); + Span parentSpan = getParentSpan(context); + if (parentSpan != null) { + spanBuilder.setParent(parentSpan.context()); + } + Span span = spanBuilder.kind(Span.Kind.PRODUCER).start(); + log.debug("Extracted result from context or thread local {}", span); + // TODO: newmetadata returns an empty composite byte buf + final CompositeByteBuf newMetadata = + PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(propagator.fields())); + TraceContext traceContext = span.context(); + if (this.isZipkinPropagationEnabled) { + injectDefaultZipkinRSocketHeaders(newMetadata, traceContext); + } + this.propagator.inject(traceContext, newMetadata, this.setter); + context.modifiedPayload = PayloadUtils.payload(payload, newMetadata); + getTracingContext(context).setSpan(span); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.name(context.getContextualName()).end(); + } + + private void injectDefaultZipkinRSocketHeaders( + CompositeByteBuf newMetadata, TraceContext traceContext) { + TracingMetadataCodec.Flags flags = + traceContext.sampled() == null + ? TracingMetadataCodec.Flags.UNDECIDED + : traceContext.sampled() + ? TracingMetadataCodec.Flags.SAMPLE + : TracingMetadataCodec.Flags.NOT_SAMPLE; + String traceId = traceContext.traceId(); + long[] traceIds = EncodingUtils.fromString(traceId); + long[] spanId = EncodingUtils.fromString(traceContext.spanId()); + long[] parentSpanId = EncodingUtils.fromString(traceContext.parentId()); + boolean isTraceId128Bit = traceIds.length == 2; + if (isTraceId128Bit) { + TracingMetadataCodec.encode128( + newMetadata.alloc(), + traceIds[0], + traceIds[1], + spanId[0], + EncodingUtils.fromString(traceContext.parentId())[0], + flags); + } else { + TracingMetadataCodec.encode64( + newMetadata.alloc(), traceIds[0], spanId[0], parentSpanId[0], flags); + } + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java new file mode 100644 index 000000000..a5d6808bd --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderObservationConvention.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationConvention; + +/** + * {@link ObservationConvention} for RSocket responder {@link RSocketContext}. + * + * @author Marcin Grzejszczak + * @since 1.1.4 + */ +public interface RSocketResponderObservationConvention + extends ObservationConvention { + + @Override + default boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java new file mode 100644 index 000000000..e3975b577 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/observation/RSocketResponderTracingObservationHandler.java @@ -0,0 +1,152 @@ +/* + * Copyright 2013-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer.observation; + +import io.micrometer.observation.Observation; +import io.micrometer.tracing.Span; +import io.micrometer.tracing.TraceContext; +import io.micrometer.tracing.Tracer; +import io.micrometer.tracing.handler.TracingObservationHandler; +import io.micrometer.tracing.internal.EncodingUtils; +import io.micrometer.tracing.propagation.Propagator; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.rsocket.Payload; +import io.rsocket.frame.FrameType; +import io.rsocket.metadata.RoutingMetadata; +import io.rsocket.metadata.TracingMetadata; +import io.rsocket.metadata.TracingMetadataCodec; +import io.rsocket.metadata.WellKnownMimeType; +import java.util.HashSet; +import java.util.Iterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class RSocketResponderTracingObservationHandler + implements TracingObservationHandler { + + private static final Logger log = + LoggerFactory.getLogger(RSocketResponderTracingObservationHandler.class); + + private final Propagator propagator; + + private final Propagator.Getter getter; + + private final Tracer tracer; + + private final boolean isZipkinPropagationEnabled; + + public RSocketResponderTracingObservationHandler( + Tracer tracer, + Propagator propagator, + Propagator.Getter getter, + boolean isZipkinPropagationEnabled) { + this.tracer = tracer; + this.propagator = propagator; + this.getter = getter; + this.isZipkinPropagationEnabled = isZipkinPropagationEnabled; + } + + @Override + public void onStart(RSocketContext context) { + Span handle = consumerSpanBuilder(context.payload, context.metadata, context.frameType); + CompositeByteBuf bufs = + PayloadUtils.cleanTracingMetadata(context.payload, new HashSet<>(propagator.fields())); + context.modifiedPayload = PayloadUtils.payload(context.payload, bufs); + getTracingContext(context).setSpan(handle); + } + + @Override + public void onError(RSocketContext context) { + Throwable error = context.getError(); + if (error != null) { + getRequiredSpan(context).error(error); + } + } + + @Override + public void onStop(RSocketContext context) { + Span span = getRequiredSpan(context); + tagSpan(context, span); + span.end(); + } + + @Override + public boolean supportsContext(Observation.Context context) { + return context instanceof RSocketContext + && ((RSocketContext) context).side == RSocketContext.Side.RESPONDER; + } + + @Override + public Tracer getTracer() { + return this.tracer; + } + + private Span consumerSpanBuilder(Payload payload, ByteBuf headers, FrameType requestType) { + Span.Builder consumerSpanBuilder = consumerSpanBuilder(payload, headers); + log.debug("Extracted result from headers {}", consumerSpanBuilder); + String name = "handle"; + if (payload.hasMetadata()) { + try { + final ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString()); + if (extract != null) { + final RoutingMetadata routingMetadata = new RoutingMetadata(extract); + final Iterator iterator = routingMetadata.iterator(); + name = requestType.name() + " " + iterator.next(); + } + } catch (Exception e) { + + } + } + return consumerSpanBuilder.kind(Span.Kind.CONSUMER).name(name).start(); + } + + private Span.Builder consumerSpanBuilder(Payload payload, ByteBuf headers) { + if (this.isZipkinPropagationEnabled && payload.hasMetadata()) { + try { + ByteBuf extract = + CompositeMetadataUtils.extract( + headers, WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString()); + if (extract != null) { + TracingMetadata tracingMetadata = TracingMetadataCodec.decode(extract); + Span.Builder builder = this.tracer.spanBuilder(); + String traceId = EncodingUtils.fromLong(tracingMetadata.traceId()); + long traceIdHigh = tracingMetadata.traceIdHigh(); + if (traceIdHigh != 0L) { + // ExtendedTraceId + traceId = EncodingUtils.fromLong(traceIdHigh) + traceId; + } + TraceContext.Builder parentBuilder = + this.tracer + .traceContextBuilder() + .sampled(tracingMetadata.isDebug() || tracingMetadata.isSampled()) + .traceId(traceId) + .spanId(EncodingUtils.fromLong(tracingMetadata.spanId())) + .parentId(EncodingUtils.fromLong(tracingMetadata.parentId())); + return builder.setParent(parentBuilder.build()); + } else { + return this.propagator.extract(headers, this.getter); + } + } catch (Exception e) { + + } + } + return this.propagator.extract(headers, this.getter); + } +} diff --git a/rsocket-micrometer/src/main/java/io/rsocket/micrometer/package-info.java b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/package-info.java new file mode 100644 index 000000000..c95f2ce02 --- /dev/null +++ b/rsocket-micrometer/src/main/java/io/rsocket/micrometer/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Transparent metrics exposure for Micrometer. + * + * @see Micrometer + */ +@NonNullApi +package io.rsocket.micrometer; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptorTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptorTest.java new file mode 100644 index 000000000..4ff072252 --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionInterceptorTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static io.rsocket.plugins.DuplexConnectionInterceptor.Type.CLIENT; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; + +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.rsocket.DuplexConnection; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class MicrometerDuplexConnectionInterceptorTest { + + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("creates MicrometerDuplexConnection") + @Test + void apply() { + assertThat(new MicrometerDuplexConnectionInterceptor(meterRegistry).apply(CLIENT, delegate)) + .isInstanceOf(MicrometerDuplexConnection.class); + } + + @DisplayName("apply throws NullPointerException with null connectionType") + @Test + void applyNullConnectionType() { + assertThatNullPointerException() + .isThrownBy( + () -> new MicrometerDuplexConnectionInterceptor(meterRegistry).apply(null, delegate)) + .withMessage("connectionType must not be null"); + } + + @DisplayName("apply throws NullPointerException with null delegate") + @Test + void applyNullDelegate() { + assertThatNullPointerException() + .isThrownBy( + () -> new MicrometerDuplexConnectionInterceptor(meterRegistry).apply(CLIENT, null)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointer exception with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnectionInterceptor(null)) + .withMessage("meterRegistry must not be null"); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java new file mode 100644 index 000000000..7806200dd --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerDuplexConnectionTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static io.rsocket.frame.FrameType.*; +import static io.rsocket.plugins.DuplexConnectionInterceptor.Type.CLIENT; +import static io.rsocket.plugins.DuplexConnectionInterceptor.Type.SERVER; +import static io.rsocket.test.TestFrames.*; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.*; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.netty.buffer.ByteBuf; +import io.rsocket.DuplexConnection; +import io.rsocket.frame.FrameType; +import io.rsocket.plugins.DuplexConnectionInterceptor.Type; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.test.StepVerifier; + +final class MicrometerDuplexConnectionTest { + + private final DuplexConnection delegate = mock(DuplexConnection.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("constructor throws NullPointerException with null connectionType") + @Test + void constructorNullConnectionType() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnection(null, delegate, meterRegistry)) + .withMessage("connectionType must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null delegate") + @Test + void constructorNullDelegate() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnection(CLIENT, null, meterRegistry)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerDuplexConnection(CLIENT, delegate, null)) + .withMessage("meterRegistry must not be null"); + } + + @DisplayName("dispose gathers metrics") + @Test + void dispose() { + new MicrometerDuplexConnection( + CLIENT, delegate, meterRegistry, Tag.of("test-key", "test-value")) + .dispose(); + + assertThat( + meterRegistry + .get("rsocket.duplex.connection.dispose") + .tag("connection.type", CLIENT.name()) + .tag("test-key", "test-value") + .counter() + .count()) + .isEqualTo(1); + } + + @DisplayName("onClose gathers metrics") + @Test + void onClose() { + when(delegate.onClose()).thenReturn(Mono.empty()); + + new MicrometerDuplexConnection( + CLIENT, delegate, meterRegistry, Tag.of("test-key", "test-value")) + .onClose() + .subscribe(Operators.drainSubscriber()); + + assertThat( + meterRegistry + .get("rsocket.duplex.connection.close") + .tag("connection.type", CLIENT.name()) + .tag("test-key", "test-value") + .counter() + .count()) + .isEqualTo(1); + } + + @DisplayName("receive gathers metrics") + @Test + void receive() { + Flux frames = + Flux.just( + createTestCancelFrame(), + createTestErrorFrame(), + createTestKeepaliveFrame(), + createTestLeaseFrame(), + createTestMetadataPushFrame(), + createTestPayloadFrame(), + createTestRequestChannelFrame(), + createTestRequestFireAndForgetFrame(), + createTestRequestNFrame(), + createTestRequestResponseFrame(), + createTestRequestStreamFrame(), + createTestSetupFrame()); + + when(delegate.receive()).thenReturn(frames); + + new MicrometerDuplexConnection( + CLIENT, delegate, meterRegistry, Tag.of("test-key", "test-value")) + .receive() + .as(StepVerifier::create) + .expectNextCount(12) + .verifyComplete(); + + assertThat(findCounter(CLIENT, CANCEL).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, COMPLETE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, ERROR).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, KEEPALIVE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, LEASE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, METADATA_PUSH).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_CHANNEL).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_FNF).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_N).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_RESPONSE).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, REQUEST_STREAM).count()).isEqualTo(1); + assertThat(findCounter(CLIENT, SETUP).count()).isEqualTo(1); + } + + @DisplayName("send gathers metrics") + @SuppressWarnings("unchecked") + @Test + void send() { + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + doNothing().when(delegate).sendFrame(Mockito.anyInt(), captor.capture()); + + final MicrometerDuplexConnection micrometerDuplexConnection = + new MicrometerDuplexConnection( + SERVER, delegate, meterRegistry, Tag.of("test-key", "test-value")); + micrometerDuplexConnection.sendFrame(1, createTestCancelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestErrorFrame()); + micrometerDuplexConnection.sendFrame(1, createTestKeepaliveFrame()); + micrometerDuplexConnection.sendFrame(1, createTestLeaseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestMetadataPushFrame()); + micrometerDuplexConnection.sendFrame(1, createTestPayloadFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestChannelFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestFireAndForgetFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestNFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestResponseFrame()); + micrometerDuplexConnection.sendFrame(1, createTestRequestStreamFrame()); + micrometerDuplexConnection.sendFrame(1, createTestSetupFrame()); + + StepVerifier.create(Flux.fromIterable(captor.getAllValues())) + .expectNextCount(12) + .verifyComplete(); + + assertThat(findCounter(SERVER, CANCEL).count()).isEqualTo(1); + assertThat(findCounter(SERVER, COMPLETE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, ERROR).count()).isEqualTo(1); + assertThat(findCounter(SERVER, KEEPALIVE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, LEASE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, METADATA_PUSH).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_CHANNEL).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_FNF).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_N).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_RESPONSE).count()).isEqualTo(1); + assertThat(findCounter(SERVER, REQUEST_STREAM).count()).isEqualTo(1); + assertThat(findCounter(SERVER, SETUP).count()).isEqualTo(1); + } + + private Counter findCounter(Type connectionType, FrameType frameType) { + return meterRegistry + .get("rsocket.frame") + .tag("connection.type", connectionType.name()) + .tag("frame.type", frameType.name()) + .tag("test-key", "test-value") + .counter(); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketInterceptorTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketInterceptorTest.java new file mode 100644 index 000000000..196ee1aa6 --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketInterceptorTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; + +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.rsocket.RSocket; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class MicrometerRSocketInterceptorTest { + + private final RSocket delegate = mock(RSocket.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("creates MicrometerRSocket") + @Test + void apply() { + assertThat(new MicrometerRSocketInterceptor(meterRegistry).apply(delegate)) + .isInstanceOf(MicrometerRSocket.class); + } + + @DisplayName("apply throws NullPointerException with null delegate") + @Test + void applyNullDelegate() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocketInterceptor(meterRegistry).apply(null)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocketInterceptor(null)) + .withMessage("meterRegistry must not be null"); + } +} diff --git a/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketTest.java b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketTest.java new file mode 100644 index 000000000..7317c5c59 --- /dev/null +++ b/rsocket-micrometer/src/test/java/io/rsocket/micrometer/MicrometerRSocketTest.java @@ -0,0 +1,146 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.micrometer; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.Mockito.RETURNS_SMART_NULLS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Tag; +import io.micrometer.core.instrument.Timer; +import io.micrometer.core.instrument.simple.SimpleMeterRegistry; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.DefaultPayload; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.SignalType; +import reactor.test.StepVerifier; + +final class MicrometerRSocketTest { + + private final RSocket delegate = mock(RSocket.class, RETURNS_SMART_NULLS); + + private final SimpleMeterRegistry meterRegistry = new SimpleMeterRegistry(); + + @DisplayName("constructor throws NullPointerException with null delegate") + @Test + void constructorNullDelegate() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocket(null, meterRegistry)) + .withMessage("delegate must not be null"); + } + + @DisplayName("constructor throws NullPointerException with null meterRegistry") + @Test + void constructorNullMeterRegistry() { + assertThatNullPointerException() + .isThrownBy(() -> new MicrometerRSocket(delegate, null)) + .withMessage("meterRegistry must not be null"); + } + + @DisplayName("fireAndForget gathers metrics") + @Test + void fireAndForget() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.fireAndForget(payload)).thenReturn(Mono.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .fireAndForget(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("request.fnf", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("metadataPush gathers metrics") + @Test + void metadataPush() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.metadataPush(payload)).thenReturn(Mono.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .metadataPush(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("metadata.push", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("requestChannel gathers metrics") + @Test + void requestChannel() { + Mono payload = Mono.just(DefaultPayload.create("test-metadata", "test-data")); + when(delegate.requestChannel(payload)).thenReturn(Flux.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .requestChannel(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("request.channel", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("requestResponse gathers metrics") + @Test + void requestResponse() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.requestResponse(payload)).thenReturn(Mono.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .requestResponse(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findTimer("request.response", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + @DisplayName("requestStream gathers metrics") + @Test + void requestStream() { + Payload payload = DefaultPayload.create("test-metadata", "test-data"); + when(delegate.requestStream(payload)).thenReturn(Flux.empty()); + + new MicrometerRSocket(delegate, meterRegistry, Tag.of("test-key", "test-value")) + .requestStream(payload) + .as(StepVerifier::create) + .verifyComplete(); + + assertThat(findCounter("request.stream", SignalType.ON_COMPLETE).count()).isEqualTo(1); + } + + private Counter findCounter(String interactionModel, SignalType signalType) { + return meterRegistry + .get(String.format("rsocket.%s", interactionModel)) + .tag("signal.type", signalType.name()) + .tag("test-key", "test-value") + .counter(); + } + + private Timer findTimer(String interactionModel, SignalType signalType) { + return meterRegistry + .get(String.format("rsocket.%s", interactionModel)) + .tag("signal.type", signalType.name()) + .tag("test-key", "test-value") + .timer(); + } +} diff --git a/rsocket-micrometer/src/test/resources/logback-test.xml b/rsocket-micrometer/src/test/resources/logback-test.xml new file mode 100644 index 000000000..56e2f9c9b --- /dev/null +++ b/rsocket-micrometer/src/test/resources/logback-test.xml @@ -0,0 +1,32 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + + + + + + diff --git a/rsocket-test/build.gradle b/rsocket-test/build.gradle new file mode 100644 index 000000000..bcdf88f28 --- /dev/null +++ b/rsocket-test/build.gradle @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + +dependencies { + api project(':rsocket-core') + api 'org.hdrhistogram:HdrHistogram' + api 'org.junit.jupiter:junit-jupiter-api' + + implementation 'io.projectreactor:reactor-test' + implementation 'org.assertj:assertj-core' + implementation 'org.mockito:mockito-core' + implementation 'org.awaitility:awaitility' + implementation 'org.slf4j:slf4j-api' +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.test") + } +} + +description = 'Test utilities for RSocket projects' diff --git a/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java new file mode 100644 index 000000000..e773b4a0d --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/BaseClientServerTest.java @@ -0,0 +1,274 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.Payload; +import io.rsocket.util.DefaultPayload; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import reactor.core.publisher.Flux; + +public abstract class BaseClientServerTest> { + public final T setup = createClientServer(); + + protected abstract T createClientServer(); + + @BeforeEach + public void init() { + setup.init(); + } + + @AfterEach + public void teardown() { + setup.tearDown(); + } + + @Test + @Timeout(10000) + public void testFireNForget10() { + long outputCount = + Flux.range(1, 10) + .flatMap(i -> setup.getRSocket().fireAndForget(testPayload(i))) + .doOnError(Throwable::printStackTrace) + .count() + .block(); + + assertThat(outputCount).isZero(); + } + + @Test + @Timeout(10000) + public void testPushMetadata10() { + long outputCount = + Flux.range(1, 10) + .flatMap(i -> setup.getRSocket().metadataPush(DefaultPayload.create("", "metadata"))) + .doOnError(Throwable::printStackTrace) + .count() + .block(); + + assertThat(outputCount).isZero(); + } + + @Test // (timeout = 10000) + public void testRequestResponse1() { + long outputCount = + Flux.range(1, 1) + .flatMap( + i -> setup.getRSocket().requestResponse(testPayload(i)).map(Payload::getDataUtf8)) + .doOnError(Throwable::printStackTrace) + .count() + .block(); + + assertThat(outputCount).isZero(); + } + + @Test + @Timeout(10000) + public void testRequestResponse10() { + long outputCount = + Flux.range(1, 10) + .flatMap( + i -> setup.getRSocket().requestResponse(testPayload(i)).map(Payload::getDataUtf8)) + .doOnError(Throwable::printStackTrace) + .count() + .block(); + + assertThat(outputCount).isEqualTo(10); + } + + private Payload testPayload(int metadataPresent) { + String metadata; + switch (metadataPresent % 5) { + case 0: + metadata = null; + break; + case 1: + metadata = ""; + break; + default: + metadata = "metadata"; + break; + } + return DefaultPayload.create("hello", metadata); + } + + @Test + @Timeout(10000) + public void testRequestResponse100() { + long outputCount = + Flux.range(1, 100) + .flatMap( + i -> setup.getRSocket().requestResponse(testPayload(i)).map(Payload::getDataUtf8)) + .doOnError(Throwable::printStackTrace) + .count() + .block(); + + assertThat(outputCount).isEqualTo(100); + } + + @Test + @Timeout(20000) + public void testRequestResponse10_000() { + long outputCount = + Flux.range(1, 10_000) + .flatMap( + i -> setup.getRSocket().requestResponse(testPayload(i)).map(Payload::getDataUtf8)) + .doOnError(Throwable::printStackTrace) + .count() + .block(); + + assertThat(outputCount).isEqualTo(10_000); + } + + @Test + @Timeout(10000) + public void testRequestStream() { + Flux publisher = setup.getRSocket().requestStream(testPayload(3)); + + long count = publisher.take(5).count().block(); + + assertThat(count).isEqualTo(5); + } + + @Test + @Timeout(10000) + public void testRequestStreamAll() { + Flux publisher = setup.getRSocket().requestStream(testPayload(3)); + + long count = publisher.count().block(); + + assertThat(count).isEqualTo(10000); + } + + @Test + @Timeout(10000) + public void testRequestStreamWithRequestN() { + CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); + ts.expect(5); + + setup.getRSocket().requestStream(testPayload(3)).subscribe(ts); + + ts.await(); + assertThat(ts.count()).isEqualTo(5); + + ts.expect(5); + ts.await(); + ts.cancel(); + + assertThat(ts.count()).isEqualTo(10); + } + + @Test + @Timeout(10000) + public void testRequestStreamWithDelayedRequestN() { + CountdownBaseSubscriber ts = new CountdownBaseSubscriber(); + + setup.getRSocket().requestStream(testPayload(3)).subscribe(ts); + + ts.expect(5); + + ts.await(); + assertThat(ts.count()).isEqualTo(5); + + ts.expect(5); + ts.await(); + ts.cancel(); + + assertThat(ts.count()).isEqualTo(10); + } + + @Test + @Timeout(10000) + public void testChannel0() { + Flux publisher = setup.getRSocket().requestChannel(Flux.empty()); + + long count = publisher.count().block(); + + assertThat(count).isZero(); + } + + @Test + @Timeout(10000) + public void testChannel1() { + Flux publisher = setup.getRSocket().requestChannel(Flux.just(testPayload(0))); + + long count = publisher.count().block(); + + assertThat(count).isOne(); + } + + @Test + @Timeout(10000) + public void testChannel3() { + Flux publisher = + setup + .getRSocket() + .requestChannel(Flux.just(testPayload(0), testPayload(1), testPayload(2))); + + long count = publisher.count().block(); + + assertThat(count).isEqualTo(3); + } + + @Test + @Timeout(10000) + public void testChannel512() { + Flux payloads = Flux.range(1, 512).map(i -> DefaultPayload.create("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(512); + } + + @Test + @Timeout(30000) + public void testChannel20_000() { + Flux payloads = Flux.range(1, 20_000).map(i -> DefaultPayload.create("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(20_000); + } + + @Test + @Timeout(60_000) + public void testChannel200_000() { + Flux payloads = Flux.range(1, 200_000).map(i -> DefaultPayload.create("hello " + i)); + + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(200_000); + } + + @Test + @Timeout(60_000) + @Disabled + public void testChannel2_000_000() { + AtomicInteger counter = new AtomicInteger(0); + + Flux payloads = Flux.range(1, 2_000_000).map(i -> DefaultPayload.create("hello " + i)); + long count = setup.getRSocket().requestChannel(payloads).count().block(); + + assertThat(count).isEqualTo(2_000_000); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java new file mode 100644 index 000000000..d065f3d71 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/ByteBufRepresentation.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.IllegalReferenceCountException; +import org.assertj.core.presentation.StandardRepresentation; + +public final class ByteBufRepresentation extends StandardRepresentation { + + @Override + protected String fallbackToStringOf(Object object) { + if (object instanceof ByteBuf) { + try { + String normalBufferString = object.toString(); + ByteBuf byteBuf = (ByteBuf) object; + if (byteBuf.readableBytes() <= 256) { + String prettyHexDump = ByteBufUtil.prettyHexDump(byteBuf); + return new StringBuilder() + .append(normalBufferString) + .append("\n") + .append(prettyHexDump) + .toString(); + } else { + return normalBufferString; + } + } catch (IllegalReferenceCountException e) { + // noops + } + } + + return super.fallbackToStringOf(object); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java new file mode 100644 index 000000000..1d6b7f69e --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/ClientSetupRule.java @@ -0,0 +1,81 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.rsocket.Closeable; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.function.Supplier; +import reactor.core.publisher.Mono; + +public class ClientSetupRule { + private static final String data = "hello world"; + private static final String metadata = "metadata"; + + private Supplier addressSupplier; + private BiFunction clientConnector; + private Function serverInit; + + private RSocket client; + private S server; + + public ClientSetupRule( + Supplier addressSupplier, + BiFunction clientTransportSupplier, + Function> serverTransportSupplier) { + this.addressSupplier = addressSupplier; + + this.serverInit = + address -> + RSocketServer.create((setup, rsocket) -> Mono.just(new TestRSocket(data, metadata))) + .bind(serverTransportSupplier.apply(address)) + .block(); + + this.clientConnector = + (address, server) -> + RSocketConnector.connectWith(clientTransportSupplier.apply(address, server)) + .doOnError(Throwable::printStackTrace) + .block(); + } + + public void init() { + T address = addressSupplier.get(); + S server = serverInit.apply(address); + client = clientConnector.apply(address, server); + } + + public void tearDown() { + server.dispose(); + } + + public RSocket getRSocket() { + return client; + } + + public String expectedPayloadData() { + return data; + } + + public String expectedPayloadMetadata() { + return metadata; + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/CountdownBaseSubscriber.java b/rsocket-test/src/main/java/io/rsocket/test/CountdownBaseSubscriber.java new file mode 100644 index 000000000..8fb948e9f --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/CountdownBaseSubscriber.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.rsocket.Payload; +import java.util.concurrent.CountDownLatch; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; + +class CountdownBaseSubscriber extends BaseSubscriber { + private CountDownLatch latch = new CountDownLatch(0); + private int count = 0; + + public void expect(int count) { + latch = new CountDownLatch((int) latch.getCount() + count); + if (upstream() != null) { + request(count); + } + } + + @Override + protected void hookOnNext(Payload value) { + count++; + latch.countDown(); + } + + @Override + protected void hookOnSubscribe(Subscription subscription) { + long count = latch.getCount(); + + if (count > 0) { + subscription.request(count); + } + } + + public void await() { + try { + latch.await(); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + public int count() { + return count; + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java new file mode 100644 index 000000000..46e807b09 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/LeaksTrackingByteBufAllocator.java @@ -0,0 +1,294 @@ +package io.rsocket.test; + +import static java.util.concurrent.locks.LockSupport.parkNanos; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import java.lang.reflect.Field; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; +import org.assertj.core.api.Assertions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Additional Utils which allows to decorate a ByteBufAllocator and track/assertOnLeaks all created + * ByteBuffs + */ +public class LeaksTrackingByteBufAllocator implements ByteBufAllocator { + static final Logger LOGGER = LoggerFactory.getLogger(LeaksTrackingByteBufAllocator.class); + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument(ByteBufAllocator allocator) { + return new LeaksTrackingByteBufAllocator(allocator, Duration.ZERO, ""); + } + + /** + * Allows to instrument any given the instance of ByteBufAllocator + * + * @param allocator + * @return + */ + public static LeaksTrackingByteBufAllocator instrument( + ByteBufAllocator allocator, Duration awaitZeroRefCntDuration, String tag) { + return new LeaksTrackingByteBufAllocator(allocator, awaitZeroRefCntDuration, tag); + } + + final ConcurrentLinkedQueue tracker = new ConcurrentLinkedQueue<>(); + + final ByteBufAllocator delegate; + + final Duration awaitZeroRefCntDuration; + + final String tag; + + private LeaksTrackingByteBufAllocator( + ByteBufAllocator delegate, Duration awaitZeroRefCntDuration, String tag) { + this.delegate = delegate; + this.awaitZeroRefCntDuration = awaitZeroRefCntDuration; + this.tag = tag; + } + + public LeaksTrackingByteBufAllocator assertHasNoLeaks() { + try { + ArrayList unreleased = new ArrayList<>(); + for (ByteBuf bb : tracker) { + if (bb.refCnt() != 0) { + unreleased.add(bb); + } + } + + final Duration awaitZeroRefCntDuration = this.awaitZeroRefCntDuration; + if (!unreleased.isEmpty() && !awaitZeroRefCntDuration.isZero()) { + final long startTime = System.currentTimeMillis(); + final long endTimeInMillis = startTime + awaitZeroRefCntDuration.toMillis(); + boolean hasUnreleased; + while (System.currentTimeMillis() <= endTimeInMillis) { + hasUnreleased = false; + for (ByteBuf bb : unreleased) { + if (bb.refCnt() != 0) { + hasUnreleased = true; + break; + } + } + + if (!hasUnreleased) { + return this; + } + + LOGGER.debug(tag + " await buffers to be released"); + for (int i = 0; i < 100; i++) { + System.gc(); + parkNanos(1000); + System.gc(); + } + } + } + + Set collected = new HashSet<>(); + for (ByteBuf buf : unreleased) { + if (buf.refCnt() != 0) { + try { + collected.add(buf); + } catch (IllegalReferenceCountException ignored) { + // fine to ignore if throws because of refCnt + } + } + } + + Assertions.assertThat( + collected + .stream() + .filter(bb -> bb.refCnt() != 0) + .peek( + bb -> { + try { + LOGGER.debug(tag + " " + resolveTrackingInfo(bb)); + } catch (Exception e) { + e.printStackTrace(); + } + })) + .describedAs("[" + tag + "] all buffers expected to be released but got ") + .isEmpty(); + } finally { + tracker.clear(); + } + return this; + } + + // Delegating logic with tracking of buffers + + @Override + public ByteBuf buffer() { + return track(delegate.buffer()); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return track(delegate.buffer(initialCapacity)); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return track(delegate.buffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf ioBuffer() { + return track(delegate.ioBuffer()); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return track(delegate.ioBuffer(initialCapacity)); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.ioBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf heapBuffer() { + return track(delegate.heapBuffer()); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return track(delegate.heapBuffer(initialCapacity)); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.heapBuffer(initialCapacity, maxCapacity)); + } + + @Override + public ByteBuf directBuffer() { + return track(delegate.directBuffer()); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return track(delegate.directBuffer(initialCapacity)); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return track(delegate.directBuffer(initialCapacity, maxCapacity)); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return track(delegate.compositeBuffer()); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + return track(delegate.compositeBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return track(delegate.compositeHeapBuffer()); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return track(delegate.compositeHeapBuffer(maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return track(delegate.compositeDirectBuffer()); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return track(delegate.compositeDirectBuffer(maxNumComponents)); + } + + @Override + public boolean isDirectBufferPooled() { + return delegate.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return delegate.calculateNewCapacity(minNewCapacity, maxCapacity); + } + + T track(T buffer) { + tracker.offer(buffer); + + return buffer; + } + + static final Class simpleLeakAwareCompositeByteBufClass; + static final Field leakFieldForComposite; + static final Class simpleLeakAwareByteBufClass; + static final Field leakFieldForNormal; + static final Field allLeaksField; + + static { + try { + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareCompositeByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareCompositeByteBufClass = aClass; + leakFieldForComposite = leakField; + } + + { + final Class aClass = Class.forName("io.netty.buffer.SimpleLeakAwareByteBuf"); + final Field leakField = aClass.getDeclaredField("leak"); + + leakField.setAccessible(true); + + simpleLeakAwareByteBufClass = aClass; + leakFieldForNormal = leakField; + } + + { + final Class aClass = + Class.forName("io.netty.util.ResourceLeakDetector$DefaultResourceLeak"); + final Field field = aClass.getDeclaredField("allLeaks"); + + field.setAccessible(true); + + allLeaksField = field; + } + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @SuppressWarnings("unchecked") + static Set resolveTrackingInfo(ByteBuf byteBuf) throws Exception { + if (ResourceLeakDetector.getLevel().ordinal() + >= ResourceLeakDetector.Level.ADVANCED.ordinal()) { + if (simpleLeakAwareCompositeByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForComposite.get(byteBuf)); + } else if (simpleLeakAwareByteBufClass.isInstance(byteBuf)) { + return (Set) allLeaksField.get(leakFieldForNormal.get(byteBuf)); + } + } + + return Collections.emptySet(); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java b/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java new file mode 100644 index 000000000..3830ec1bc --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/PerfTest.java @@ -0,0 +1,17 @@ +package io.rsocket.test; + +import java.lang.annotation.*; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +/** + * {@code @PerfTest} is used to signal that the annotated test class or method is performance test, + * and is disabled unless enabled via setting the {@code TEST_PERF_ENABLED} environment variable to + * {@code true}. + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@EnabledIfEnvironmentVariable(named = "TEST_PERF_ENABLED", matches = "(?i)true") +@Test +public @interface PerfTest {} diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingClient.java b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java new file mode 100644 index 000000000..14740950a --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/PingClient.java @@ -0,0 +1,88 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.function.BiFunction; +import org.HdrHistogram.Recorder; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class PingClient { + + private final Payload payload; + private final Mono client; + + public PingClient(Mono client) { + this.client = client; + this.payload = ByteBufPayload.create("hello"); + } + + public Recorder startTracker(Duration interval) { + final Recorder histogram = new Recorder(3600000000000L, 3); + Flux.interval(interval) + .doOnNext( + aLong -> { + System.out.println("---- PING/ PONG HISTO ----"); + histogram + .getIntervalHistogram() + .outputPercentileDistribution(System.out, 5, 1000.0, false); + System.out.println("---- PING/ PONG HISTO ----"); + }) + .subscribe(); + return histogram; + } + + public Flux requestResponsePingPong(int count, final Recorder histogram) { + return pingPong(RSocket::requestResponse, count, histogram); + } + + public Flux requestStreamPingPong(int count, final Recorder histogram) { + return pingPong(RSocket::requestStream, count, histogram); + } + + Flux pingPong( + BiFunction> interaction, + int count, + final Recorder histogram) { + return Flux.usingWhen( + client, + rsocket -> + Flux.range(1, count) + .flatMap( + i -> { + long start = System.nanoTime(); + return Flux.from(interaction.apply(rsocket, payload.retain())) + .doOnNext(Payload::release) + .doFinally( + signalType -> { + long diff = System.nanoTime() - start; + histogram.recordValue(diff); + }); + }, + 64), + rsocket -> { + rsocket.dispose(); + return rsocket.onClose(); + }) + .doOnError(Throwable::printStackTrace); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java new file mode 100644 index 000000000..47f40a59d --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/PingHandler.java @@ -0,0 +1,59 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.util.ByteBufPayload; +import java.util.concurrent.ThreadLocalRandom; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class PingHandler implements SocketAcceptor { + + private final Payload pong; + + public PingHandler() { + byte[] data = new byte[1024]; + ThreadLocalRandom.current().nextBytes(data); + pong = ByteBufPayload.create(data); + } + + public PingHandler(byte[] data) { + pong = ByteBufPayload.create(data); + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + return Mono.just( + new RSocket() { + @Override + public Mono requestResponse(Payload payload) { + payload.release(); + return Mono.just(pong.retain()); + } + + @Override + public Flux requestStream(Payload payload) { + payload.release(); + return Flux.range(0, 100).map(v -> pong.retain()); + } + }); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/SlowTest.java b/rsocket-test/src/main/java/io/rsocket/test/SlowTest.java new file mode 100644 index 000000000..596cc0ffb --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/SlowTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +/** + * {@code @SlowTest} is used to signal that the annotated test class or test method is slow running + * and will be disabled unless enabled via setting the {@code TEST_SLOW_ENABLED} environment + * variable to {@code true}. + */ +@Target({ElementType.TYPE, ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +@Documented +@EnabledIfEnvironmentVariable(named = "TEST_SLOW_ENABLED", matches = "(?i)true") +@Test +public @interface SlowTest {} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java new file mode 100644 index 000000000..57a00e229 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestDuplexConnection.java @@ -0,0 +1,166 @@ +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.PayloadFrameCodec; +import java.net.SocketAddress; +import java.util.function.BiFunction; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.Scannable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +public class TestDuplexConnection implements DuplexConnection { + + final ByteBufAllocator allocator; + final Sinks.Many inbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.Many outbound = Sinks.unsafe().many().unicast().onBackpressureError(); + final Sinks.One close = Sinks.one(); + + public TestDuplexConnection( + CoreSubscriber outboundSubscriber, boolean trackLeaks) { + this.outbound.asFlux().subscribe(outboundSubscriber); + this.allocator = + trackLeaks + ? LeaksTrackingByteBufAllocator.instrument(ByteBufAllocator.DEFAULT) + : ByteBufAllocator.DEFAULT; + } + + @Override + public void dispose() { + this.inbound.tryEmitComplete(); + this.outbound.tryEmitComplete(); + this.close.tryEmitEmpty(); + } + + @Override + public Mono onClose() { + return this.close.asMono(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) {} + + @Override + public Flux receive() { + return this.inbound + .asFlux() + .transform( + Operators.lift( + (BiFunction< + Scannable, + CoreSubscriber, + CoreSubscriber>) + ByteBufReleaserOperator::create)); + } + + @Override + public ByteBufAllocator alloc() { + return this.allocator; + } + + @Override + public SocketAddress remoteAddress() { + return new SocketAddress() { + @Override + public String toString() { + return "Test"; + } + }; + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + this.outbound.tryEmitNext(frame); + } + + public void sendPayloadFrame( + int streamId, ByteBuf data, @Nullable ByteBuf metadata, boolean complete) { + sendFrame( + streamId, + PayloadFrameCodec.encode(this.allocator, streamId, false, complete, true, metadata, data)); + } + + static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + static CoreSubscriber create( + Scannable scannable, CoreSubscriber actual) { + return new ByteBufReleaserOperator(actual); + } + + final CoreSubscriber actual; + + Subscription s; + + public ByteBufReleaserOperator(CoreSubscriber actual) { + this.actual = actual; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + this.actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + this.actual.onNext(buf); + buf.release(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java new file mode 100644 index 000000000..1e66abc5e --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestFrames.java @@ -0,0 +1,108 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.rsocket.Payload; +import io.rsocket.frame.*; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; + +/** Test instances of all frame types. */ +public final class TestFrames { + private static final ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private static final Payload emptyPayload = DefaultPayload.create(Unpooled.EMPTY_BUFFER); + + private TestFrames() {} + + /** @return {@link ByteBuf} representing test instance of Cancel frame */ + public static ByteBuf createTestCancelFrame() { + return CancelFrameCodec.encode(allocator, 1); + } + + /** @return {@link ByteBuf} representing test instance of Error frame */ + public static ByteBuf createTestErrorFrame() { + return ErrorFrameCodec.encode(allocator, 1, new RuntimeException()); + } + + /** @return {@link ByteBuf} representing test instance of Extension frame */ + public static ByteBuf createTestExtensionFrame() { + return ExtensionFrameCodec.encode( + allocator, 1, 1, Unpooled.EMPTY_BUFFER, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Keep-Alive frame */ + public static ByteBuf createTestKeepaliveFrame() { + return KeepAliveFrameCodec.encode(allocator, false, 1, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Lease frame */ + public static ByteBuf createTestLeaseFrame() { + return LeaseFrameCodec.encode(allocator, 1, 1, null); + } + + /** @return {@link ByteBuf} representing test instance of Metadata-Push frame */ + public static ByteBuf createTestMetadataPushFrame() { + return MetadataPushFrameCodec.encode(allocator, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Payload frame */ + public static ByteBuf createTestPayloadFrame() { + return PayloadFrameCodec.encode(allocator, 1, false, true, false, null, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Request-Channel frame */ + public static ByteBuf createTestRequestChannelFrame() { + return RequestChannelFrameCodec.encode( + allocator, 1, false, false, 1, null, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Fire-and-Forget frame */ + public static ByteBuf createTestRequestFireAndForgetFrame() { + return RequestFireAndForgetFrameCodec.encode(allocator, 1, false, null, Unpooled.EMPTY_BUFFER); + } + + /** @return {@link ByteBuf} representing test instance of Request-N frame */ + public static ByteBuf createTestRequestNFrame() { + return RequestNFrameCodec.encode(allocator, 1, 1); + } + + /** @return {@link ByteBuf} representing test instance of Request-Response frame */ + public static ByteBuf createTestRequestResponseFrame() { + return RequestResponseFrameCodec.encodeReleasingPayload(allocator, 1, emptyPayload); + } + + /** @return {@link ByteBuf} representing test instance of Request-Stream frame */ + public static ByteBuf createTestRequestStreamFrame() { + return RequestStreamFrameCodec.encodeReleasingPayload(allocator, 1, 1L, emptyPayload); + } + + /** @return {@link ByteBuf} representing test instance of Setup frame */ + public static ByteBuf createTestSetupFrame() { + return SetupFrameCodec.encode( + allocator, + false, + 1, + 1, + Unpooled.EMPTY_BUFFER, + "metadataType", + "dataType", + EmptyPayload.INSTANCE); + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java new file mode 100644 index 000000000..1b294e394 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestRSocket.java @@ -0,0 +1,109 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import static java.util.concurrent.locks.LockSupport.parkNanos; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.util.ByteBufPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class TestRSocket implements RSocket { + private final String data; + private final String metadata; + + private final AtomicLong observedInteractions = new AtomicLong(); + private final AtomicLong activeInteractions = new AtomicLong(); + + public TestRSocket(String data, String metadata) { + this.data = data; + this.metadata = metadata; + } + + @Override + public Mono requestResponse(Payload payload) { + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.just(ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); + } + + @Override + public Flux requestStream(Payload payload) { + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Flux.range(1, 10_000) + .map(l -> ByteBufPayload.create(data, metadata)) + .doFinally(__ -> activeInteractions.getAndDecrement()); + } + + @Override + public Mono metadataPush(Payload payload) { + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); + } + + @Override + public Mono fireAndForget(Payload payload) { + activeInteractions.getAndIncrement(); + payload.release(); + observedInteractions.getAndIncrement(); + return Mono.empty().doFinally(__ -> activeInteractions.getAndDecrement()); + } + + @Override + public Flux requestChannel(Publisher payloads) { + activeInteractions.getAndIncrement(); + observedInteractions.getAndIncrement(); + return Flux.from(payloads).doFinally(__ -> activeInteractions.getAndDecrement()); + } + + public boolean awaitAllInteractionTermination(Duration duration) { + long end = duration.plusNanos(System.nanoTime()).toNanos(); + long activeNow; + while ((activeNow = activeInteractions.get()) > 0) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return activeNow == 0; + } + + public boolean awaitUntilObserved(int interactions, Duration duration) { + long end = System.nanoTime() + duration.toNanos(); + long observed; + while ((observed = observedInteractions.get()) < interactions) { + if (System.nanoTime() >= end) { + return false; + } + parkNanos(100); + } + + return observed >= interactions; + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TestSubscriber.java b/rsocket-test/src/main/java/io/rsocket/test/TestSubscriber.java new file mode 100644 index 000000000..62b6c242b --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TestSubscriber.java @@ -0,0 +1,67 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; + +import io.rsocket.Payload; +import org.mockito.Mockito; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; + +public class TestSubscriber { + public static Subscriber create() { + return create(Long.MAX_VALUE); + } + + public static Subscriber create(long initialRequest) { + @SuppressWarnings("unchecked") + Subscriber mock = mock(Subscriber.class); + + Mockito.doAnswer( + invocation -> { + if (initialRequest > 0) { + ((Subscription) invocation.getArguments()[0]).request(initialRequest); + } + return null; + }) + .when(mock) + .onSubscribe(any(Subscription.class)); + + return mock; + } + + public static Payload anyPayload() { + return any(Payload.class); + } + + public static Subscriber createCancelling() { + @SuppressWarnings("unchecked") + Subscriber mock = mock(Subscriber.class); + + Mockito.doAnswer( + invocation -> { + ((Subscription) invocation.getArguments()[0]).cancel(); + return null; + }) + .when(mock) + .onSubscribe(any(Subscription.class)); + + return mock; + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java new file mode 100644 index 000000000..1fcca97db --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TransportTest.java @@ -0,0 +1,984 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.RSocketErrorException; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.plugins.DuplexConnectionInterceptor; +import io.rsocket.resume.InMemoryResumableFramesStore; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.util.ByteBufPayload; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.Arrays; +import java.util.concurrent.CancellationException; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.BiFunction; +import java.util.function.Predicate; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.zip.GZIPInputStream; +import org.assertj.core.api.Assertions; +import org.assertj.core.api.Assumptions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Disposable; +import reactor.core.Disposables; +import reactor.core.Exceptions; +import reactor.core.Fuseable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Hooks; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; +import reactor.core.publisher.Sinks; +import reactor.core.scheduler.Scheduler; +import reactor.core.scheduler.Schedulers; +import reactor.test.StepVerifier; +import reactor.util.Logger; +import reactor.util.Loggers; + +public interface TransportTest { + + Logger logger = Loggers.getLogger(TransportTest.class); + + String MOCK_DATA = "test-data"; + String MOCK_METADATA = "metadata"; + String LARGE_DATA = read("words.shakespeare.txt.gz"); + Payload LARGE_PAYLOAD = ByteBufPayload.create(LARGE_DATA, LARGE_DATA); + + static String read(String resourceName) { + try (BufferedReader br = + new BufferedReader( + new InputStreamReader( + new GZIPInputStream( + TransportTest.class.getClassLoader().getResourceAsStream(resourceName))))) { + + return br.lines().map(String::toLowerCase).collect(Collectors.joining("\n\r")); + } catch (Throwable e) { + throw new RuntimeException(e); + } + } + + @BeforeEach + default void setup() { + Hooks.onOperatorDebug(); + } + + @AfterEach + default void close() { + try { + logger.debug("------------------Awaiting communication to finish------------------"); + getTransportPair().responder.awaitAllInteractionTermination(getTimeout()); + logger.debug("---------------------Disposing Client And Server--------------------"); + getTransportPair().dispose(); + getTransportPair().awaitClosed(getTimeout()); + logger.debug("------------------------Disposing Schedulers-------------------------"); + Schedulers.parallel().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.boundedElastic().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + Schedulers.single().disposeGracefully().timeout(getTimeout(), Mono.empty()).block(); + logger.debug("---------------------------Leaks Checking----------------------------"); + RuntimeException throwable = + new RuntimeException() { + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } + + @Override + public String getMessage() { + return Arrays.toString(getSuppressed()); + } + }; + + try { + getTransportPair().byteBufAllocator2.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + try { + getTransportPair().byteBufAllocator1.assertHasNoLeaks(); + } catch (Throwable t) { + throwable = Exceptions.addSuppressed(throwable, t); + } + + if (throwable.getSuppressed().length > 0) { + throw throwable; + } + } finally { + Hooks.resetOnOperatorDebug(); + Schedulers.resetOnHandleError(); + } + } + + default Payload createTestPayload(int metadataPresent) { + String metadata1; + + switch (metadataPresent % 5) { + case 0: + metadata1 = null; + break; + case 1: + metadata1 = ""; + break; + default: + metadata1 = MOCK_METADATA; + break; + } + String metadata = metadata1; + + return ByteBufPayload.create(MOCK_DATA, metadata); + } + + @DisplayName("makes 10 fireAndForget requests") + @Test + default void fireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(createTestPayload(i))) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + @DisplayName("makes 10 fireAndForget with Large Payload in Requests") + @Test + default void largePayloadFireAndForget10() { + Flux.range(1, 10) + .flatMap(i -> getClient().fireAndForget(LARGE_PAYLOAD.retain())) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + default RSocket getClient() { + return getTransportPair().getClient(); + } + + Duration getTimeout(); + + TransportPair getTransportPair(); + + @DisplayName("makes 10 metadataPush requests") + @Test + default void metadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", "test-metadata"))) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + @DisplayName("makes 10 metadataPush with Large Metadata in requests") + @Test + default void largePayloadMetadataPush10() { + Assumptions.assumeThat(getTransportPair().withResumability).isFalse(); + Flux.range(1, 10) + .flatMap(i -> getClient().metadataPush(ByteBufPayload.create("", LARGE_DATA))) + .as(StepVerifier::create) + .expectComplete() + .verify(getTimeout()); + + getTransportPair().responder.awaitUntilObserved(10, getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 0 payloads") + @Test + default void requestChannel0() { + getClient() + .requestChannel(Flux.empty()) + .as(StepVerifier::create) + .expectErrorSatisfies( + t -> + Assertions.assertThat(t) + .isInstanceOf(CancellationException.class) + .hasMessage("Empty Source")) + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 1 payloads") + @Test + default void requestChannel1() { + getClient() + .requestChannel(Mono.just(createTestPayload(0))) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(1)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 200,000 payloads") + @Test + default void requestChannel200_000() { + Flux payloads = Flux.range(0, 200_000).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .limitRate(8) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(200_000)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 50 large payloads") + @Test + default void largePayloadRequestChannel50() { + Flux payloads = Flux.range(0, 50).map(__ -> LARGE_PAYLOAD.retain()); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(50)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 20,000 payloads") + @Test + default void requestChannel20_000() { + Flux payloads = Flux.range(0, 20_000).map(metadataPresent -> createTestPayload(7)); + + getClient() + .requestChannel(payloads) + .doOnNext(this::assertChannelPayload) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(20_000)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 2,000,000 payloads") + @SlowTest + default void requestChannel2_000_000() { + Flux payloads = Flux.range(0, 2_000_000).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .limitRate(8) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(2_000_000)) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestChannel request with 3 payloads") + @Test + default void requestChannel3() { + AtomicLong requested = new AtomicLong(); + Flux payloads = + Flux.range(0, 3).doOnRequest(requested::addAndGet).map(this::createTestPayload); + + getClient() + .requestChannel(payloads) + .doOnNext(Payload::release) + .as(publisher -> StepVerifier.create(publisher, 3)) + .thenConsumeWhile(new PayloadPredicate(3)) + .expectComplete() + .verify(getTimeout()); + + Assertions.assertThat(requested.get()).isEqualTo(3L); + } + + @DisplayName("makes 1 requestChannel request with 256 payloads") + @Test + default void requestChannel256() { + AtomicInteger counter = new AtomicInteger(); + Flux payloads = + Flux.defer( + () -> { + final int subscription = counter.getAndIncrement(); + return Flux.range(0, 256) + .map(i -> "S{" + subscription + "}: Data{" + i + "}") + .map(data -> ByteBufPayload.create(data)); + }); + final Scheduler scheduler = Schedulers.fromExecutorService(Executors.newFixedThreadPool(12)); + + try { + Flux.range(0, 1024) + .flatMap(v -> Mono.fromRunnable(() -> check(payloads)).subscribeOn(scheduler), 12) + .blockLast(); + } finally { + scheduler.disposeGracefully().block(); + } + } + + default void check(Flux payloads) { + getClient() + .requestChannel(payloads) + .doOnNext(ReferenceCounted::release) + .limitRate(8) + .as(StepVerifier::create) + .thenConsumeWhile(new PayloadPredicate(256)) + .as("expected 256 items") + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestResponse request") + @Test + default void requestResponse1() { + getClient() + .requestResponse(createTestPayload(1)) + .doOnNext(this::assertPayload) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .expectNextCount(1) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10 requestResponse requests") + @Test + default void requestResponse10() { + Flux.range(1, 10) + .flatMap( + i -> + getClient() + .requestResponse(createTestPayload(i)) + .doOnNext(v -> assertPayload(v)) + .doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(10) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 100 requestResponse requests") + @Test + default void requestResponse100() { + Flux.range(1, 100) + .flatMap(i -> getClient().requestResponse(createTestPayload(i)).doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(100) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 50 requestResponse requests") + @Test + default void largePayloadRequestResponse50() { + Flux.range(1, 50) + .flatMap( + i -> getClient().requestResponse(LARGE_PAYLOAD.retain()).doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(50) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 10,000 requestResponse requests") + @Test + default void requestResponse10_000() { + Flux.range(1, 10_000) + .flatMap(i -> getClient().requestResponse(createTestPayload(i)).doOnNext(Payload::release)) + .as(StepVerifier::create) + .expectNextCount(10_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and receives 10,000 responses") + @Test + default void requestStream10_000() { + getClient() + .requestStream(createTestPayload(3)) + .doOnNext(this::assertPayload) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .expectNextCount(10_000) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and receives 5 responses") + @Test + default void requestStream5() { + getClient() + .requestStream(createTestPayload(3)) + .doOnNext(this::assertPayload) + .doOnNext(Payload::release) + .take(5) + .as(StepVerifier::create) + .expectNextCount(5) + .expectComplete() + .verify(getTimeout()); + } + + @DisplayName("makes 1 requestStream request and consumes result incrementally") + @Test + default void requestStreamDelayedRequestN() { + getClient() + .requestStream(createTestPayload(3)) + .take(10) + .doOnNext(Payload::release) + .as(StepVerifier::create) + .thenRequest(5) + .expectNextCount(5) + .thenRequest(5) + .expectNextCount(5) + .expectComplete() + .verify(getTimeout()); + } + + default void assertPayload(Payload p) { + TransportPair transportPair = getTransportPair(); + if (!transportPair.expectedPayloadData().equals(p.getDataUtf8()) + || !transportPair.expectedPayloadMetadata().equals(p.getMetadataUtf8())) { + throw new IllegalStateException("Unexpected payload"); + } + } + + default void assertChannelPayload(Payload p) { + if (!MOCK_DATA.equals(p.getDataUtf8()) || !MOCK_METADATA.equals(p.getMetadataUtf8())) { + throw new IllegalStateException("Unexpected payload"); + } + } + + class TransportPair implements Disposable { + + private static final String data = "hello world"; + private static final String metadata = "metadata"; + + private final boolean withResumability; + private final boolean runClientWithAsyncInterceptors; + private final boolean runServerWithAsyncInterceptors; + + private final LeaksTrackingByteBufAllocator byteBufAllocator1 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Client"); + private final LeaksTrackingByteBufAllocator byteBufAllocator2 = + LeaksTrackingByteBufAllocator.instrument( + ByteBufAllocator.DEFAULT, Duration.ofMinutes(1), "Server"); + + private final TestRSocket responder; + + private final RSocket client; + + private final S server; + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier) { + this(addressSupplier, clientTransportSupplier, serverTransportSupplier, false); + } + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier, + boolean withRandomFragmentation) { + this( + addressSupplier, + clientTransportSupplier, + serverTransportSupplier, + withRandomFragmentation, + false); + } + + public TransportPair( + Supplier addressSupplier, + TriFunction clientTransportSupplier, + BiFunction> serverTransportSupplier, + boolean withRandomFragmentation, + boolean withResumability) { + Schedulers.onHandleError((t, e) -> e.printStackTrace()); + Schedulers.resetFactory(); + + this.withResumability = withResumability; + + T address = addressSupplier.get(); + + this.runClientWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + this.runServerWithAsyncInterceptors = ThreadLocalRandom.current().nextBoolean(); + + ByteBufAllocator allocatorToSupply1; + ByteBufAllocator allocatorToSupply2; + if (ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.ADVANCED + || ResourceLeakDetector.getLevel() == ResourceLeakDetector.Level.PARANOID) { + logger.info("Using LeakTrackingByteBufAllocator"); + allocatorToSupply1 = byteBufAllocator1; + allocatorToSupply2 = byteBufAllocator2; + } else { + allocatorToSupply1 = ByteBufAllocator.DEFAULT; + allocatorToSupply2 = ByteBufAllocator.DEFAULT; + } + responder = new TestRSocket(TransportPair.data, metadata); + final RSocketServer rSocketServer = + RSocketServer.create((setup, sendingSocket) -> Mono.just(responder)) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .interceptors( + registry -> { + if (runServerWithAsyncInterceptors && !withResumability) { + logger.info( + "Perform Integration Test with Async Interceptors Enabled For Server"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection, "server")) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Server", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(100, 1000))) + : duplexConnection); + } + }); + + if (withResumability) { + rSocketServer.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("server", token, Integer.MAX_VALUE))); + } + + if (withRandomFragmentation) { + rSocketServer.fragment(ThreadLocalRandom.current().nextInt(256, 512)); + } + + server = + rSocketServer.bind(serverTransportSupplier.apply(address, allocatorToSupply2)).block(); + + final RSocketConnector rSocketConnector = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMillis(10), Duration.ofHours(1)) + .interceptors( + registry -> { + if (runClientWithAsyncInterceptors && !withResumability) { + logger.info( + "Perform Integration Test with Async Interceptors Enabled For Client"); + registry + .forConnection( + (type, duplexConnection) -> + new AsyncDuplexConnection(duplexConnection, "client")) + .forSocketAcceptor( + delegate -> + (connectionSetupPayload, sendingSocket) -> + delegate + .accept(connectionSetupPayload, sendingSocket) + .subscribeOn(Schedulers.parallel())); + } + + if (withResumability) { + registry.forConnection( + (type, duplexConnection) -> + type == DuplexConnectionInterceptor.Type.SOURCE + ? new DisconnectingDuplexConnection( + "Client", + duplexConnection, + Duration.ofMillis( + ThreadLocalRandom.current().nextInt(10, 1500))) + : duplexConnection); + } + }); + + if (withResumability) { + rSocketConnector.resume( + new Resume() + .storeFactory( + token -> new InMemoryResumableFramesStore("client", token, Integer.MAX_VALUE))); + } + + if (withRandomFragmentation) { + rSocketConnector.fragment(ThreadLocalRandom.current().nextInt(256, 512)); + } + + client = + rSocketConnector + .connect(clientTransportSupplier.apply(address, server, allocatorToSupply1)) + .doOnError(Throwable::printStackTrace) + .block(); + } + + @Override + public void dispose() { + logger.info("terminating transport pair"); + client.dispose(); + } + + RSocket getClient() { + return client; + } + + public String expectedPayloadData() { + return data; + } + + public String expectedPayloadMetadata() { + return metadata; + } + + public void awaitClosed(Duration timeout) { + logger.info("awaiting termination of transport pair"); + logger.info( + "wrappers combination: client{async=" + + runClientWithAsyncInterceptors + + "; resume=" + + withResumability + + "} server{async=" + + runServerWithAsyncInterceptors + + "; resume=" + + withResumability + + "}"); + client + .onClose() + .doOnSubscribe(s -> logger.info("Client termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Client termination stage=" + s)) + .onErrorResume(t -> Mono.empty()) + .doOnTerminate(() -> logger.info("Client terminated. Terminating Server")) + .then(Mono.fromRunnable(server::dispose)) + .then( + server + .onClose() + .doOnSubscribe( + s -> logger.info("Server termination stage=onSubscribe(" + s + ")")) + .doOnEach(s -> logger.info("Server termination stage=" + s))) + .onErrorResume(t -> Mono.empty()) + .block(timeout); + + logger.info("TransportPair has been terminated"); + } + + private static class AsyncDuplexConnection implements DuplexConnection { + + private final DuplexConnection duplexConnection; + private String tag; + private final ByteBufReleaserOperator bufReleaserOperator; + + public AsyncDuplexConnection(DuplexConnection duplexConnection, String tag) { + this.duplexConnection = duplexConnection; + this.tag = tag; + this.bufReleaserOperator = new ByteBufReleaserOperator(); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + duplexConnection.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + duplexConnection.sendErrorAndClose(e); + } + + @Override + public Flux receive() { + return duplexConnection + .receive() + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done before PO")) + .subscribeOn(Schedulers.boundedElastic()) + .doOnNext(ByteBuf::retain) + .publishOn(Schedulers.boundedElastic(), Integer.MAX_VALUE) + .doOnTerminate(() -> logger.info("[" + this + "] Receive is done after PO")) + .doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::safeRelease) + .transform( + Operators.lift( + (__, actual) -> { + bufReleaserOperator.actual = actual; + return bufReleaserOperator; + })); + } + + @Override + public ByteBufAllocator alloc() { + return duplexConnection.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return duplexConnection.remoteAddress(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError( + duplexConnection + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")), + bufReleaserOperator + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] BufferReleaser is done"))); + } + + @Override + public void dispose() { + duplexConnection.dispose(); + } + + @Override + public String toString() { + return "AsyncDuplexConnection{" + + "duplexConnection=" + + duplexConnection + + ", tag='" + + tag + + '\'' + + ", bufReleaserOperator=" + + bufReleaserOperator + + '}'; + } + } + + private static class DisconnectingDuplexConnection implements DuplexConnection { + + private final String tag; + final DuplexConnection source; + final Duration delay; + final Disposable.Swap disposables = Disposables.swap(); + + DisconnectingDuplexConnection(String tag, DuplexConnection source, Duration delay) { + this.tag = tag; + this.source = source; + this.delay = delay; + } + + @Override + public void dispose() { + disposables.dispose(); + source.dispose(); + } + + @Override + public Mono onClose() { + return source + .onClose() + .doOnTerminate(() -> logger.info("[" + this + "] Source Connection is done")); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + source.sendFrame(streamId, frame); + } + + @Override + public void sendErrorAndClose(RSocketErrorException errorException) { + source.sendErrorAndClose(errorException); + } + + boolean receivedFirst; + + @Override + public Flux receive() { + return source + .receive() + .doOnSubscribe( + __ -> logger.warn("Tag {}. Subscribing Connection[{}]", tag, source.hashCode())) + .doOnNext( + bb -> { + if (!receivedFirst) { + receivedFirst = true; + disposables.replace( + Mono.delay(delay) + .takeUntilOther(source.onClose()) + .subscribe( + __ -> { + logger.warn( + "Tag {}. Disposing Connection[{}]", tag, source.hashCode()); + source.dispose(); + })); + } + }); + } + + @Override + public ByteBufAllocator alloc() { + return source.alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return source.remoteAddress(); + } + + @Override + public String toString() { + return "DisconnectingDuplexConnection{" + + "tag='" + + tag + + '\'' + + ", source=" + + source + + ", disposables=" + + disposables + + '}'; + } + } + + private static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + CoreSubscriber actual; + final Sinks.Empty closeableMonoSink; + + Subscription s; + + public ByteBufReleaserOperator() { + this.closeableMonoSink = Sinks.unsafe().empty(); + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + try { + actual.onNext(buf); + } finally { + buf.release(); + } + } + + Mono onClose() { + return closeableMonoSink.asMono(); + } + + @Override + public void onError(Throwable t) { + actual.onError(t); + closeableMonoSink.tryEmitError(t); + } + + @Override + public void onComplete() { + actual.onComplete(); + closeableMonoSink.tryEmitEmpty(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + closeableMonoSink.tryEmitEmpty(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public String toString() { + return "ByteBufReleaserOperator{" + + "isActualPresent=" + + (actual != null) + + ", " + + "isSubscriptionPresent=" + + (s != null) + + '}'; + } + } + } + + class PayloadPredicate implements Predicate { + final int expectedCnt; + int cnt; + + public PayloadPredicate(int expectedCnt) { + this.expectedCnt = expectedCnt; + } + + @Override + public boolean test(Payload p) { + boolean shouldConsume = cnt++ < expectedCnt; + if (!shouldConsume) { + logger.info( + "Metadata: \n\r{}\n\rData:{}", + p.hasMetadata() + ? new ByteBufRepresentation().fallbackToStringOf(p.sliceMetadata()) + : "Empty", + new ByteBufRepresentation().fallbackToStringOf(p.sliceData())); + } + return shouldConsume; + } + } +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java new file mode 100644 index 000000000..87a1d4dbf --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/TriFunction.java @@ -0,0 +1,6 @@ +package io.rsocket.test; + +@FunctionalInterface +public interface TriFunction { + R apply(T1 t1, T2 t2, T3 t3); +} diff --git a/rsocket-test/src/main/java/io/rsocket/test/package-info.java b/rsocket-test/src/main/java/io/rsocket/test/package-info.java new file mode 100644 index 000000000..600ac2b82 --- /dev/null +++ b/rsocket-test/src/main/java/io/rsocket/test/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** Utilities for testing RSocket components. */ +@NonNullApi +package io.rsocket.test; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation new file mode 100644 index 000000000..0c33b5ff7 --- /dev/null +++ b/rsocket-test/src/main/resources/META-INF/services/org.assertj.core.presentation.Representation @@ -0,0 +1,16 @@ +# +# Copyright 2015-2018 the original author or authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +io.rsocket.test.ByteBufRepresentation \ No newline at end of file diff --git a/rsocket-test/src/main/resources/words.shakespeare.txt.gz b/rsocket-test/src/main/resources/words.shakespeare.txt.gz new file mode 100644 index 000000000..422a4b331 Binary files /dev/null and b/rsocket-test/src/main/resources/words.shakespeare.txt.gz differ diff --git a/rsocket-transport-local/build.gradle b/rsocket-transport-local/build.gradle new file mode 100644 index 000000000..fc32125e2 --- /dev/null +++ b/rsocket-transport-local/build.gradle @@ -0,0 +1,41 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' +} + +dependencies { + api project(':rsocket-core') + + testImplementation project(':rsocket-test') + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.local") + } +} + +description = 'Local RSocket transport implementation' diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java new file mode 100644 index 000000000..1b3779e85 --- /dev/null +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalClientTransport.java @@ -0,0 +1,93 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.internal.UnboundedProcessor; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import java.util.Objects; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} in the + * same JVM. + */ +public final class LocalClientTransport implements ClientTransport { + + private final String name; + + private final ByteBufAllocator allocator; + + private LocalClientTransport(String name, ByteBufAllocator allocator) { + this.name = name; + this.allocator = allocator; + } + + /** + * Creates a new instance. + * + * @param name the name of the {@link ClientTransport} instance to connect to + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ + public static LocalClientTransport create(String name) { + Objects.requireNonNull(name, "name must not be null"); + + return create(name, ByteBufAllocator.DEFAULT); + } + + /** + * Creates a new instance. + * + * @param name the name of the {@link ClientTransport} instance to connect to + * @param allocator the allocator used by {@link ClientTransport} instance + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ + public static LocalClientTransport create(String name, ByteBufAllocator allocator) { + Objects.requireNonNull(name, "name must not be null"); + Objects.requireNonNull(allocator, "allocator must not be null"); + + return new LocalClientTransport(name, allocator); + } + + @Override + public Mono connect() { + return Mono.defer( + () -> { + ServerTransport.ConnectionAcceptor server = LocalServerTransport.findServer(name); + if (server == null) { + return Mono.error(new IllegalArgumentException("Could not find server: " + name)); + } + + Sinks.One inSink = Sinks.one(); + Sinks.One outSink = Sinks.one(); + UnboundedProcessor in = new UnboundedProcessor(inSink::tryEmitEmpty); + UnboundedProcessor out = new UnboundedProcessor(outSink::tryEmitEmpty); + + Mono onClose = inSink.asMono().and(outSink.asMono()); + + server.apply(new LocalDuplexConnection(name, allocator, out, in, onClose)).subscribe(); + + return Mono.just( + new LocalDuplexConnection(name, allocator, in, out, onClose)); + }); + } +} diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java new file mode 100644 index 000000000..c1d0fd2a3 --- /dev/null +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalDuplexConnection.java @@ -0,0 +1,198 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.internal.UnboundedProcessor; +import java.net.SocketAddress; +import java.util.Objects; +import org.reactivestreams.Subscription; +import reactor.core.CoreSubscriber; +import reactor.core.Fuseable; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Operators; + +/** An implementation of {@link DuplexConnection} that connects inside the same JVM. */ +final class LocalDuplexConnection implements DuplexConnection { + + private final LocalSocketAddress address; + private final ByteBufAllocator allocator; + private final UnboundedProcessor in; + + private final Mono onClose; + + private final UnboundedProcessor out; + + /** + * Creates a new instance. + * + * @param name the name assigned to this local connection + * @param in the inbound {@link ByteBuf}s + * @param out the outbound {@link ByteBuf}s + * @param onClose the closing notifier + * @throws NullPointerException if {@code in}, {@code out}, or {@code onClose} are {@code null} + */ + LocalDuplexConnection( + String name, + ByteBufAllocator allocator, + UnboundedProcessor in, + UnboundedProcessor out, + Mono onClose) { + this.address = new LocalSocketAddress(name); + this.allocator = Objects.requireNonNull(allocator, "allocator must not be null"); + this.in = Objects.requireNonNull(in, "in must not be null"); + this.out = Objects.requireNonNull(out, "out must not be null"); + this.onClose = Objects.requireNonNull(onClose, "onClose must not be null"); + } + + @Override + public void dispose() { + out.onComplete(); + } + + @Override + public boolean isDisposed() { + return out.isDisposed(); + } + + @Override + public Mono onClose() { + return onClose; + } + + @Override + public Flux receive() { + return in.transform( + Operators.lift( + (__, actual) -> new ByteBufReleaserOperator(actual, this))); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + if (streamId == 0) { + out.tryEmitPrioritized(frame); + } else { + out.tryEmitNormal(frame); + } + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(allocator, 0, e); + out.tryEmitFinal(errorFrame); + } + + @Override + public ByteBufAllocator alloc() { + return allocator; + } + + @Override + public SocketAddress remoteAddress() { + return address; + } + + @Override + public String toString() { + return "LocalDuplexConnection{" + "address=" + address + "hash=" + hashCode() + '}'; + } + + static class ByteBufReleaserOperator + implements CoreSubscriber, Subscription, Fuseable.QueueSubscription { + + final CoreSubscriber actual; + final LocalDuplexConnection parent; + + Subscription s; + + public ByteBufReleaserOperator( + CoreSubscriber actual, LocalDuplexConnection parent) { + this.actual = actual; + this.parent = parent; + } + + @Override + public void onSubscribe(Subscription s) { + if (Operators.validate(this.s, s)) { + this.s = s; + actual.onSubscribe(this); + } + } + + @Override + public void onNext(ByteBuf buf) { + try { + actual.onNext(buf); + } finally { + buf.release(); + } + } + + @Override + public void onError(Throwable t) { + parent.out.onError(t); + actual.onError(t); + } + + @Override + public void onComplete() { + parent.out.onComplete(); + actual.onComplete(); + } + + @Override + public void request(long n) { + s.request(n); + } + + @Override + public void cancel() { + s.cancel(); + parent.out.onComplete(); + } + + @Override + public int requestFusion(int requestedMode) { + return Fuseable.NONE; + } + + @Override + public ByteBuf poll() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public int size() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + + @Override + public void clear() { + throw new UnsupportedOperationException(NOT_SUPPORTED_MESSAGE); + } + } +} diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java new file mode 100644 index 000000000..975cb6793 --- /dev/null +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalServerTransport.java @@ -0,0 +1,178 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.Closeable; +import io.rsocket.DuplexConnection; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import java.util.Objects; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.util.annotation.Nullable; + +/** + * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} in the + * same JVM. + */ +public final class LocalServerTransport implements ServerTransport { + + private static final ConcurrentMap registry = + new ConcurrentHashMap<>(); + + private final String name; + + private LocalServerTransport(String name) { + this.name = name; + } + + /** + * Creates an instance. + * + * @param name the name of this {@link ServerTransport} that clients will connect to + * @return a new instance + * @throws NullPointerException if {@code name} is {@code null} + */ + public static LocalServerTransport create(String name) { + Objects.requireNonNull(name, "name must not be null"); + return new LocalServerTransport(name); + } + + /** + * Creates an instance with a random name. + * + * @return a new instance with a random name + */ + public static LocalServerTransport createEphemeral() { + return create(UUID.randomUUID().toString()); + } + + /** + * Remove an instance from the JVM registry. + * + * @param name the local transport instance to free + * @throws NullPointerException if {@code name} is {@code null} + */ + public static void dispose(String name) { + Objects.requireNonNull(name, "name must not be null"); + ServerCloseableAcceptor sca = registry.remove(name); + if (sca != null) { + sca.dispose(); + } + } + + /** + * Retrieves an instance of {@link ConnectionAcceptor} based on the name of its {@code + * LocalServerTransport}. Returns {@code null} if that server is not registered. + * + * @param name the name of the server to retrieve + * @return the server if it has been registered, {@code null} otherwise + * @throws NullPointerException if {@code name} is {@code null} + */ + static @Nullable ConnectionAcceptor findServer(String name) { + Objects.requireNonNull(name, "name must not be null"); + + return registry.get(name); + } + + /** Return the name associated with this local server instance. */ + String getName() { + return name; + } + + /** + * Return a new {@link LocalClientTransport} connected to this {@code LocalServerTransport} + * through its {@link #getName()}. + */ + public LocalClientTransport clientTransport() { + return LocalClientTransport.create(name); + } + + @Override + public Mono start(ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor, "acceptor must not be null"); + return Mono.create( + sink -> { + ServerCloseableAcceptor closeable = new ServerCloseableAcceptor(name, acceptor); + if (registry.putIfAbsent(name, closeable) != null) { + sink.error(new IllegalStateException("name already registered: " + name)); + } + sink.success(closeable); + }); + } + + @SuppressWarnings({"ReactorTransformationOnMonoVoid", "CallingSubscribeInNonBlockingScope"}) + static class ServerCloseableAcceptor implements ConnectionAcceptor, Closeable { + + private final LocalSocketAddress address; + + private final ConnectionAcceptor acceptor; + + private final Set activeConnections = ConcurrentHashMap.newKeySet(); + + private final Sinks.Empty onClose = Sinks.unsafe().empty(); + + ServerCloseableAcceptor(String name, ConnectionAcceptor acceptor) { + Objects.requireNonNull(name, "name must not be null"); + this.address = new LocalSocketAddress(name); + this.acceptor = acceptor; + } + + @Override + public Mono apply(DuplexConnection duplexConnection) { + activeConnections.add(duplexConnection); + duplexConnection + .onClose() + .doFinally(__ -> activeConnections.remove(duplexConnection)) + .subscribe(); + return acceptor.apply(duplexConnection); + } + + @Override + public void dispose() { + if (!registry.remove(address.getName(), this)) { + // already disposed + return; + } + + Mono.whenDelayError( + activeConnections + .stream() + .peek(DuplexConnection::dispose) + .map(DuplexConnection::onClose) + .collect(Collectors.toList())) + .subscribe(null, onClose::tryEmitError, onClose::tryEmitEmpty); + } + + @Override + @SuppressWarnings("ConstantConditions") + public boolean isDisposed() { + return onClose.scan(Scannable.Attr.TERMINATED) || onClose.scan(Scannable.Attr.CANCELLED); + } + + @Override + public Mono onClose() { + return onClose.asMono(); + } + } +} diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java new file mode 100644 index 000000000..4d0da126a --- /dev/null +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/LocalSocketAddress.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import java.net.SocketAddress; +import java.util.Objects; + +/** An implementation of {@link SocketAddress} representing a local connection. */ +public final class LocalSocketAddress extends SocketAddress { + + private static final long serialVersionUID = -7513338854585475473L; + + private final String name; + + /** + * Creates a new instance. + * + * @param name the name representing the address + * @throws NullPointerException if {@code name} is {@code null} + */ + public LocalSocketAddress(String name) { + this.name = Objects.requireNonNull(name, "name must not be null"); + } + + /** Return the name for this connection. */ + public String getName() { + return name; + } + + @Override + public String toString() { + return "[local address] " + name; + } +} diff --git a/rsocket-transport-local/src/main/java/io/rsocket/transport/local/package-info.java b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/package-info.java new file mode 100644 index 000000000..6a67f6af4 --- /dev/null +++ b/rsocket-transport-local/src/main/java/io/rsocket/transport/local/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The local RSocket transport implementation. */ +@NonNullApi +package io.rsocket.transport.local; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java new file mode 100644 index 000000000..095de3f0e --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalClientTransportTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.rsocket.Closeable; +import java.time.Duration; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.test.StepVerifier; + +final class LocalClientTransportTest { + + @DisplayName("connects to server") + @Test + void connect() { + LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); + + Closeable closeable = + serverTransport.start(duplexConnection -> duplexConnection.receive().then()).block(); + + try { + LocalClientTransport.create(serverTransport.getName()) + .connect() + .doOnNext(d -> d.receive().subscribe()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + closeable.dispose(); + closeable.onClose().block(Duration.ofSeconds(5)); + } + } + + @DisplayName("generates error if server not started") + @Test + void connectNoServer() { + LocalClientTransport.create("test-name") + .connect() + .as(StepVerifier::create) + .verifyErrorMessage("Could not find server: test-name"); + } + + @DisplayName("creates client") + @Test + void create() { + assertThat(LocalClientTransport.create("test-name")).isNotNull(); + } + + @DisplayName("throws NullPointerException with null name") + @Test + void createNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalClientTransport.create(null)) + .withMessage("name must not be null"); + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java new file mode 100644 index 000000000..9228e2d05 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalPingPong.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.test.PingClient; +import io.rsocket.test.PingHandler; +import java.time.Duration; +import org.HdrHistogram.Recorder; +import reactor.core.publisher.Mono; + +public final class LocalPingPong { + + public static void main(String... args) { + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(LocalServerTransport.create("test-local-server")) + .block(); + + Mono client = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(LocalClientTransport.create("test-local-server")); + + PingClient pingClient = new PingClient(client); + + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); + + int count = 1_000_000_000; + + pingClient + .requestResponsePingPong(count, recorder) + .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) + .blockLast(); + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java new file mode 100644 index 000000000..28c1dacac --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..8ae16a0a5 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalResumableWithFragmentationTransportTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalResumableWithFragmentationTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalResumableWithFragmentationTransportTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java new file mode 100644 index 000000000..e4edafc39 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalServerTransportTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +final class LocalServerTransportTest { + + @DisplayName("create throws NullPointerException with null name") + @Test + void createNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.create(null)) + .withMessage("name must not be null"); + } + + @DisplayName("dispose removes name from registry") + @Test + void dispose() { + LocalServerTransport.dispose("test-name"); + } + + @DisplayName("dispose throws NullPointerException with null name") + @Test + void disposeNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.dispose(null)) + .withMessage("name must not be null"); + } + + @DisplayName("creates transports with ephemeral names") + @Test + void ephemeral() { + LocalServerTransport serverTransport1 = LocalServerTransport.createEphemeral(); + LocalServerTransport serverTransport2 = LocalServerTransport.createEphemeral(); + + assertThat(serverTransport1.getName()).isNotEqualTo(serverTransport2.getName()); + } + + @DisplayName("returns the server by name") + @Test + void findServer() { + LocalServerTransport serverTransport = LocalServerTransport.createEphemeral(); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + + assertThat(LocalServerTransport.findServer(serverTransport.getName())).isNotNull(); + } + + @DisplayName("returns null if server hasn't been started") + @Test + void findServerMissingName() { + assertThat(LocalServerTransport.findServer("test-name")).isNull(); + } + + @DisplayName("findServer throws NullPointerException with null name") + @Test + void findServerNullName() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.findServer(null)) + .withMessage("name must not be null"); + } + + @DisplayName("creates transport with name") + @Test + void named() { + LocalServerTransport serverTransport = LocalServerTransport.create("test-name"); + + assertThat(serverTransport.getName()).isEqualTo("test-name"); + } + + @DisplayName("starts local server transport") + @Test + void start() { + LocalServerTransport ephemeral = LocalServerTransport.createEphemeral(); + try { + ephemeral + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } finally { + LocalServerTransport.dispose(ephemeral.getName()); + } + } + + @DisplayName("start throws NullPointerException with null acceptor") + @Test + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy(() -> LocalServerTransport.createEphemeral().start(null)) + .withMessage("acceptor must not be null"); + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalSocketAddressTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalSocketAddressTest.java new file mode 100644 index 000000000..8ad7b70ce --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalSocketAddressTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +final class LocalSocketAddressTest { + + @DisplayName("constructor throws NullPointerException with null name") + @Test + void constructorNullName() { + assertThatNullPointerException() + .isThrownBy(() -> new LocalSocketAddress(null)) + .withMessage("name must not be null"); + } + + @DisplayName("returns the configured name") + @Test + void name() { + assertThat(new LocalSocketAddress("test-name").getName()).isEqualTo("test-name"); + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java new file mode 100644 index 000000000..87ad2105b --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportTest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalTransportTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> "LocalTransportTest-" + testInfo.getDisplayName() + "-" + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address)); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java new file mode 100644 index 000000000..3ca5f5911 --- /dev/null +++ b/rsocket-transport-local/src/test/java/io/rsocket/transport/local/LocalTransportWithFragmentationTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.local; + +import io.rsocket.test.TransportTest; +import java.time.Duration; +import java.util.UUID; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInfo; + +final class LocalTransportWithFragmentationTest implements TransportTest { + + private TransportPair transportPair; + + @BeforeEach + void createTestPair(TestInfo testInfo) { + transportPair = + new TransportPair<>( + () -> + "LocalTransportWithFragmentationTest-" + + testInfo.getDisplayName() + + "-" + + UUID.randomUUID(), + (address, server, allocator) -> LocalClientTransport.create(address, allocator), + (address, allocator) -> LocalServerTransport.create(address), + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(1); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-local/src/test/resources/logback-test.xml b/rsocket-transport-local/src/test/resources/logback-test.xml new file mode 100644 index 000000000..5c92235c2 --- /dev/null +++ b/rsocket-transport-local/src/test/resources/logback-test.xml @@ -0,0 +1,49 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + ./test-out.log + false + + %-5relative %-5level %logger{35} - %msg%n + + + + + + + + + + + + + + + + + + + diff --git a/rsocket-transport-netty/build.gradle b/rsocket-transport-netty/build.gradle new file mode 100644 index 000000000..39a5ceac5 --- /dev/null +++ b/rsocket-transport-netty/build.gradle @@ -0,0 +1,60 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +plugins { + id 'java-library' + id 'maven-publish' + id 'signing' + id "com.google.osdetector" version "1.4.0" +} + +def os_suffix = "" +if (osdetector.classifier in ["linux-x86_64", "linux-aarch_64", "osx-x86_64", "osx-aarch_64", "windows-x86_64"]) { + os_suffix = "::" + osdetector.classifier +} + +dependencies { + api project(':rsocket-core') + api "io.projectreactor.netty:reactor-netty-core" + api "io.projectreactor.netty:reactor-netty-http" + api 'org.slf4j:slf4j-api' + + testImplementation project(':rsocket-test') + testImplementation 'io.projectreactor:reactor-test' + testImplementation 'org.assertj:assertj-core' + testImplementation 'org.mockito:mockito-core' + testImplementation 'org.mockito:mockito-junit-jupiter' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' + + testRuntimeOnly 'org.bouncycastle:bcpkix-jdk15on' + testRuntimeOnly 'ch.qos.logback:logback-classic' + testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine' + testRuntimeOnly 'io.netty:netty-tcnative-boringssl-static' + os_suffix +} + +jar { + manifest { + attributes("Automatic-Module-Name": "rsocket.transport.netty") + } +} + +test { + minHeapSize = "512m" + maxHeapSize = "4096m" +} + +description = 'Reactor Netty RSocket transport implementations (TCP, Websocket)' diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java new file mode 100644 index 000000000..d7b368a3e --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/RSocketLengthCodec.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_SIZE; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; + +/** + * An extension to the Netty {@link LengthFieldBasedFrameDecoder} that encapsulates the + * RSocket-specific frame length header details. + */ +public final class RSocketLengthCodec extends LengthFieldBasedFrameDecoder { + + /** Creates a new instance of the decoder, specifying the RSocket frame length header size. */ + public RSocketLengthCodec() { + this(FRAME_LENGTH_MASK); + } + + /** + * Creates a new instance of the decoder, specifying the RSocket frame length header size. + * + * @param maxFrameLength maximum allowed frame length for incoming rsocket frames + */ + public RSocketLengthCodec(int maxFrameLength) { + super(maxFrameLength, 0, FRAME_LENGTH_SIZE, 0, 0); + } + + /** + * Simplified non-netty focused decode usage. + * + * @param in the input buffer to read data from. + * @return decoded buffer or null is none available. + * @see #decode(ChannelHandlerContext, ByteBuf) + * @throws Exception if any error happens. + */ + public Object decode(ByteBuf in) throws Exception { + return decode(null, in); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java new file mode 100644 index 000000000..f5d36269c --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/TcpDuplexConnection.java @@ -0,0 +1,98 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.frame.FrameLengthCodec; +import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; +import java.util.Objects; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; + +/** An implementation of {@link DuplexConnection} that connects via TCP. */ +public final class TcpDuplexConnection extends BaseDuplexConnection { + private final String side; + private final Connection connection; + + /** + * Creates a new instance + * + * @param connection the {@link Connection} for managing the server + */ + public TcpDuplexConnection(Connection connection) { + this("unknown", connection); + } + + /** + * Creates a new instance + * + * @param connection the {@link Connection} for managing the server + */ + public TcpDuplexConnection(String side, Connection connection) { + this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; + + connection.outbound().send(sender).then().doFinally(__ -> connection.dispose()).subscribe(); + } + + @Override + public ByteBufAllocator alloc() { + return connection.channel().alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); + } + + @Override + protected void doOnClose() { + connection.dispose(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(FrameLengthCodec.encode(alloc(), errorFrame.readableBytes(), errorFrame)); + } + + @Override + public Flux receive() { + return connection.inbound().receive().map(FrameLengthCodec::frame); + } + + @Override + public void sendFrame(int streamId, ByteBuf frame) { + super.sendFrame(streamId, FrameLengthCodec.encode(alloc(), frame.readableBytes(), frame)); + } + + @Override + public String toString() { + return "TcpDuplexConnection{" + "side='" + side + '\'' + ", connection=" + connection + '}'; + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java new file mode 100644 index 000000000..8f1170c5b --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/WebsocketDuplexConnection.java @@ -0,0 +1,109 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.transport.netty; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.rsocket.DuplexConnection; +import io.rsocket.RSocketErrorException; +import io.rsocket.frame.ErrorFrameCodec; +import io.rsocket.internal.BaseDuplexConnection; +import java.net.SocketAddress; +import java.util.Objects; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; + +/** + * An implementation of {@link DuplexConnection} that connects via a Websocket. + * + *

rsocket-java strongly assumes that each ByteBuf is encoded with the length. This is not true + * for message oriented transports so this must be specifically dropped from Frames sent and + * stitched back on for frames received. + */ +public final class WebsocketDuplexConnection extends BaseDuplexConnection { + private final String side; + private final Connection connection; + + /** + * Creates a new instance + * + * @param connection the {@link Connection} to for managing the server + */ + public WebsocketDuplexConnection(Connection connection) { + this("unknown", connection); + } + + /** + * Creates a new instance + * + * @param connection the {@link Connection} to for managing the server + */ + public WebsocketDuplexConnection(String side, Connection connection) { + this.connection = Objects.requireNonNull(connection, "connection must not be null"); + this.side = side; + + connection + .outbound() + .sendObject(sender.map(BinaryWebSocketFrame::new)) + .then() + .doFinally(__ -> connection.dispose()) + .subscribe(); + } + + @Override + public ByteBufAllocator alloc() { + return connection.channel().alloc(); + } + + @Override + public SocketAddress remoteAddress() { + return connection.channel().remoteAddress(); + } + + @Override + protected void doOnClose() { + connection.dispose(); + } + + @Override + public Mono onClose() { + return Mono.whenDelayError(super.onClose(), connection.onTerminate()); + } + + @Override + public Flux receive() { + return connection.inbound().receive(); + } + + @Override + public void sendErrorAndClose(RSocketErrorException e) { + final ByteBuf errorFrame = ErrorFrameCodec.encode(alloc(), 0, e); + sender.tryEmitFinal(errorFrame); + } + + @Override + public String toString() { + return "WebsocketDuplexConnection{" + + "side='" + + side + + '\'' + + ", connection=" + + connection + + '}'; + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java new file mode 100644 index 000000000..84214b98c --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/TcpClientTransport.java @@ -0,0 +1,121 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.client; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.DuplexConnection; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.RSocketLengthCodec; +import io.rsocket.transport.netty.TcpDuplexConnection; +import java.net.InetSocketAddress; +import java.util.Objects; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; + +/** + * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} via TCP. + */ +public final class TcpClientTransport implements ClientTransport { + + private final TcpClient client; + private final int maxFrameLength; + + private TcpClientTransport(TcpClient client, int maxFrameLength) { + this.client = client; + this.maxFrameLength = maxFrameLength; + } + + /** + * Creates a new instance connecting to localhost + * + * @param port the port to connect to + * @return a new instance + */ + public static TcpClientTransport create(int port) { + TcpClient tcpClient = TcpClient.create().port(port); + return create(tcpClient); + } + + /** + * Creates a new instance + * + * @param bindAddress the address to connect to + * @param port the port to connect to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ + public static TcpClientTransport create(String bindAddress, int port) { + Objects.requireNonNull(bindAddress, "bindAddress must not be null"); + + TcpClient tcpClient = TcpClient.create().host(bindAddress).port(port); + return create(tcpClient); + } + + /** + * Creates a new instance + * + * @param address the address to connect to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ + public static TcpClientTransport create(InetSocketAddress address) { + Objects.requireNonNull(address, "address must not be null"); + + TcpClient tcpClient = TcpClient.create().remoteAddress(() -> address); + return create(tcpClient); + } + + /** + * Creates a new instance + * + * @param client the {@link TcpClient} to use + * @return a new instance + * @throws NullPointerException if {@code client} is {@code null} + */ + public static TcpClientTransport create(TcpClient client) { + return create(client, FRAME_LENGTH_MASK); + } + + /** + * Creates a new instance + * + * @param client the {@link TcpClient} to use + * @param maxFrameLength max frame length being sent over the connection + * @return a new instance + * @throws NullPointerException if {@code client} is {@code null} + */ + public static TcpClientTransport create(TcpClient client, int maxFrameLength) { + Objects.requireNonNull(client, "client must not be null"); + + return new TcpClientTransport(client, maxFrameLength); + } + + @Override + public int maxFrameLength() { + return maxFrameLength; + } + + @Override + public Mono connect() { + return client + .doOnConnected(c -> c.addHandlerLast(new RSocketLengthCodec(maxFrameLength))) + .connect() + .map(connection -> new TcpDuplexConnection("client", connection)); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java new file mode 100644 index 000000000..86be47893 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/WebsocketClientTransport.java @@ -0,0 +1,177 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.client; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; +import io.rsocket.DuplexConnection; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import java.net.InetSocketAddress; +import java.net.URI; +import java.util.Arrays; +import java.util.Objects; +import java.util.function.Consumer; +import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.client.WebsocketClientSpec; +import reactor.netty.tcp.TcpClient; + +/** + * An implementation of {@link ClientTransport} that connects to a {@link ServerTransport} over + * WebSocket. + */ +public final class WebsocketClientTransport implements ClientTransport { + + private static final String DEFAULT_PATH = "/"; + + private final HttpClient client; + + private final String path; + + private HttpHeaders headers = new DefaultHttpHeaders(); + + private final WebsocketClientSpec.Builder specBuilder = + WebsocketClientSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); + + private WebsocketClientTransport(HttpClient client, String path) { + Objects.requireNonNull(client, "HttpClient must not be null"); + Objects.requireNonNull(path, "path must not be null"); + this.client = client; + this.path = path.startsWith("/") ? path : "/" + path; + } + + /** + * Creates a new instance connecting to localhost + * + * @param port the port to connect to + * @return a new instance + */ + public static WebsocketClientTransport create(int port) { + return create(TcpClient.create().port(port)); + } + + /** + * Creates a new instance + * + * @param bindAddress the address to connect to + * @param port the port to connect to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ + public static WebsocketClientTransport create(String bindAddress, int port) { + return create(TcpClient.create().host(bindAddress).port(port)); + } + + /** + * Creates a new instance + * + * @param address the address to connect to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ + public static WebsocketClientTransport create(InetSocketAddress address) { + Objects.requireNonNull(address, "address must not be null"); + return create(TcpClient.create().remoteAddress(() -> address)); + } + + /** + * Creates a new instance + * + * @param client the {@link TcpClient} to use + * @return a new instance + * @throws NullPointerException if {@code client} or {@code path} is {@code null} + */ + public static WebsocketClientTransport create(TcpClient client) { + return new WebsocketClientTransport(HttpClient.from(client), DEFAULT_PATH); + } + + /** + * Creates a new instance + * + * @param uri the URI to connect to + * @return a new instance + * @throws NullPointerException if {@code uri} is {@code null} + */ + public static WebsocketClientTransport create(URI uri) { + Objects.requireNonNull(uri, "uri must not be null"); + boolean isSecure = uri.getScheme().equals("wss") || uri.getScheme().equals("https"); + TcpClient client = + (isSecure ? TcpClient.create().secure() : TcpClient.create()) + .host(uri.getHost()) + .port(uri.getPort() == -1 ? (isSecure ? 443 : 80) : uri.getPort()); + return new WebsocketClientTransport(HttpClient.from(client), uri.getPath()); + } + + /** + * Creates a new instance + * + * @param client the {@link HttpClient} to use + * @param path the path to request + * @return a new instance + * @throws NullPointerException if {@code client} or {@code path} is {@code null} + */ + public static WebsocketClientTransport create(HttpClient client, String path) { + return new WebsocketClientTransport(client, path); + } + + /** + * Add a header and value(s) to use for the WebSocket handshake request. + * + * @param name the header name + * @param values the header value(s) + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketClientTransport header(String name, String... values) { + if (values != null) { + Arrays.stream(values).forEach(value -> headers.add(name, value)); + } + return this; + } + + /** + * Provide a consumer to customize properties of the {@link WebsocketClientSpec} to use for + * WebSocket upgrades. The consumer is invoked immediately. + * + * @param configurer the configurer to apply to the spec + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketClientTransport webSocketSpec(Consumer configurer) { + configurer.accept(specBuilder); + return this; + } + + @Override + public int maxFrameLength() { + return specBuilder.build().maxFramePayloadLength(); + } + + @Override + public Mono connect() { + return client + .headers(headers -> headers.add(this.headers)) + .websocket(specBuilder.build()) + .uri(path) + .connect() + .map(connection -> new WebsocketDuplexConnection("client", connection)); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/package-info.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/package-info.java new file mode 100644 index 000000000..4567f2012 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/client/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The Netty-based RSocket client transport implementations. */ +@NonNullApi +package io.rsocket.transport.netty.client; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/package-info.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/package-info.java new file mode 100644 index 000000000..599500cff --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The Netty-based RSocket transport implementations. */ +@NonNullApi +package io.rsocket.transport.netty; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java new file mode 100644 index 000000000..33cff28b4 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/BaseWebsocketServerTransport.java @@ -0,0 +1,64 @@ +package io.rsocket.transport.netty.server; + +import static io.netty.channel.ChannelHandler.*; +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.transport.ServerTransport; +import java.util.function.Consumer; +import java.util.function.Function; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.WebsocketServerSpec; + +abstract class BaseWebsocketServerTransport< + SELF extends BaseWebsocketServerTransport, T extends Closeable> + implements ServerTransport { + private static final Logger logger = LoggerFactory.getLogger(BaseWebsocketServerTransport.class); + private static final ChannelHandler pongHandler = new PongHandler(); + + static Function serverConfigurer = + server -> server.doOnConnection(connection -> connection.addHandlerLast(pongHandler)); + + final WebsocketServerSpec.Builder specBuilder = + WebsocketServerSpec.builder().maxFramePayloadLength(FRAME_LENGTH_MASK); + + /** + * Provide a consumer to customize properties of the {@link WebsocketServerSpec} to use for + * WebSocket upgrades. The consumer is invoked immediately. + * + * @param configurer the configurer to apply to the spec + * @return the same instance for method chaining + * @since 1.0.1 + */ + @SuppressWarnings("unchecked") + public SELF webSocketSpec(Consumer configurer) { + configurer.accept(specBuilder); + return (SELF) this; + } + + @Override + public int maxFrameLength() { + return specBuilder.build().maxFramePayloadLength(); + } + + @Sharable + private static class PongHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof PongWebSocketFrame) { + logger.debug("received WebSocket Pong Frame"); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + ctx.fireChannelRead(msg); + } + } + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java new file mode 100644 index 000000000..7e98905ff --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/CloseableChannel.java @@ -0,0 +1,87 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import io.rsocket.Closeable; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.util.Objects; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableChannel; + +/** + * An implementation of {@link Closeable} that wraps a {@link DisposableChannel}, enabling + * close-ability and exposing the {@link DisposableChannel}'s address. + */ +public final class CloseableChannel implements Closeable { + + /** For forward compatibility: remove when RSocket compiles against Reactor 1.0. */ + private static final Method channelAddressMethod; + + static { + try { + channelAddressMethod = DisposableChannel.class.getMethod("address"); + } catch (NoSuchMethodException ex) { + throw new IllegalStateException("Expected address method", ex); + } + } + + private final DisposableChannel channel; + + /** + * Creates a new instance + * + * @param channel the {@link DisposableChannel} to wrap + * @throws NullPointerException if {@code context} is {@code null} + */ + CloseableChannel(DisposableChannel channel) { + this.channel = Objects.requireNonNull(channel, "channel must not be null"); + } + + /** + * Return local server selector channel address. + * + * @return local {@link InetSocketAddress} + * @see DisposableChannel#address() + */ + public InetSocketAddress address() { + try { + return (InetSocketAddress) channel.address(); + } catch (ClassCastException | NoSuchMethodError e) { + try { + return (InetSocketAddress) channelAddressMethod.invoke(this.channel); + } catch (Exception ex) { + throw new IllegalStateException("Unable to obtain address", ex); + } + } + } + + @Override + public void dispose() { + channel.dispose(); + } + + @Override + public boolean isDisposed() { + return channel.isDisposed(); + } + + @Override + public Mono onClose() { + return channel.onDispose(); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java new file mode 100644 index 000000000..32562c4a4 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/TcpServerTransport.java @@ -0,0 +1,124 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; + +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.RSocketLengthCodec; +import io.rsocket.transport.netty.TcpDuplexConnection; +import java.net.InetSocketAddress; +import java.util.Objects; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpServer; + +/** + * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via TCP. + */ +public final class TcpServerTransport implements ServerTransport { + + private final TcpServer server; + private final int maxFrameLength; + + private TcpServerTransport(TcpServer server, int maxFrameLength) { + this.server = server; + this.maxFrameLength = maxFrameLength; + } + + /** + * Creates a new instance binding to localhost + * + * @param port the port to bind to + * @return a new instance + */ + public static TcpServerTransport create(int port) { + TcpServer server = TcpServer.create().port(port); + return create(server); + } + + /** + * Creates a new instance + * + * @param bindAddress the address to bind to + * @param port the port to bind to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ + public static TcpServerTransport create(String bindAddress, int port) { + Objects.requireNonNull(bindAddress, "bindAddress must not be null"); + TcpServer server = TcpServer.create().host(bindAddress).port(port); + return create(server); + } + + /** + * Creates a new instance + * + * @param address the address to bind to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ + public static TcpServerTransport create(InetSocketAddress address) { + Objects.requireNonNull(address, "address must not be null"); + return create(address.getHostName(), address.getPort()); + } + + /** + * Creates a new instance + * + * @param server the {@link TcpServer} to use + * @return a new instance + * @throws NullPointerException if {@code server} is {@code null} + */ + public static TcpServerTransport create(TcpServer server) { + return create(server, FRAME_LENGTH_MASK); + } + + /** + * Creates a new instance + * + * @param server the {@link TcpServer} to use + * @param maxFrameLength max frame length being sent over the connection + * @return a new instance + * @throws NullPointerException if {@code server} is {@code null} + */ + public static TcpServerTransport create(TcpServer server, int maxFrameLength) { + Objects.requireNonNull(server, "server must not be null"); + return new TcpServerTransport(server, maxFrameLength); + } + + @Override + public int maxFrameLength() { + return maxFrameLength; + } + + @Override + public Mono start(ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor, "acceptor must not be null"); + return server + .doOnConnection( + c -> { + c.addHandlerLast(new RSocketLengthCodec(maxFrameLength)); + acceptor + .apply(new TcpDuplexConnection("server", c)) + .then(Mono.never()) + .subscribe(c.disposeSubscriber()); + }) + .bind() + .map(CloseableChannel::new); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java new file mode 100644 index 000000000..db13720e7 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketRouteTransport.java @@ -0,0 +1,87 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import io.rsocket.Closeable; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import java.util.Objects; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import org.reactivestreams.Publisher; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.HttpServerRoutes; +import reactor.netty.http.websocket.WebsocketInbound; +import reactor.netty.http.websocket.WebsocketOutbound; + +/** + * An implementation of {@link ServerTransport} that connects via Websocket and listens on specified + * routes. + */ +public final class WebsocketRouteTransport + extends BaseWebsocketServerTransport { + + private final String path; + + private final Consumer routesBuilder; + + private final HttpServer server; + + /** + * Creates a new instance + * + * @param server the {@link HttpServer} to use + * @param routesBuilder the builder for the routes that will be listened on + * @param path the path foe each route + */ + public WebsocketRouteTransport( + HttpServer server, Consumer routesBuilder, String path) { + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); + this.routesBuilder = Objects.requireNonNull(routesBuilder, "routesBuilder must not be null"); + this.path = Objects.requireNonNull(path, "path must not be null"); + } + + @Override + public Mono start(ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor, "acceptor must not be null"); + return server + .route( + routes -> { + routesBuilder.accept(routes); + routes.ws(path, newHandler(acceptor), specBuilder.build()); + }) + .bind() + .map(CloseableChannel::new); + } + + /** + * Creates a new Websocket handler + * + * @param acceptor the {@link ConnectionAcceptor} to use with the handler + * @return a new Websocket handler + * @throws NullPointerException if {@code acceptor} is {@code null} + */ + public static BiFunction> newHandler( + ConnectionAcceptor acceptor) { + return (in, out) -> + acceptor + .apply(new WebsocketDuplexConnection("server", (Connection) in)) + .then(out.neverComplete()); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java new file mode 100644 index 000000000..4fe736fad --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/WebsocketServerTransport.java @@ -0,0 +1,127 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.WebsocketDuplexConnection; +import java.net.InetSocketAddress; +import java.util.Arrays; +import java.util.Objects; +import reactor.core.publisher.Mono; +import reactor.netty.Connection; +import reactor.netty.http.server.HttpServer; + +/** + * An implementation of {@link ServerTransport} that connects to a {@link ClientTransport} via a + * Websocket. + */ +public final class WebsocketServerTransport + extends BaseWebsocketServerTransport { + + private final HttpServer server; + + private HttpHeaders headers = new DefaultHttpHeaders(); + + private WebsocketServerTransport(HttpServer server) { + this.server = serverConfigurer.apply(Objects.requireNonNull(server, "server must not be null")); + } + + /** + * Creates a new instance binding to localhost + * + * @param port the port to bind to + * @return a new instance + */ + public static WebsocketServerTransport create(int port) { + HttpServer httpServer = HttpServer.create().port(port); + return create(httpServer); + } + + /** + * Creates a new instance + * + * @param bindAddress the address to bind to + * @param port the port to bind to + * @return a new instance + * @throws NullPointerException if {@code bindAddress} is {@code null} + */ + public static WebsocketServerTransport create(String bindAddress, int port) { + Objects.requireNonNull(bindAddress, "bindAddress must not be null"); + HttpServer httpServer = HttpServer.create().host(bindAddress).port(port); + return create(httpServer); + } + + /** + * Creates a new instance + * + * @param address the address to bind to + * @return a new instance + * @throws NullPointerException if {@code address} is {@code null} + */ + public static WebsocketServerTransport create(InetSocketAddress address) { + Objects.requireNonNull(address, "address must not be null"); + return create(address.getHostName(), address.getPort()); + } + + /** + * Creates a new instance + * + * @param server the {@link HttpServer} to use + * @return a new instance + * @throws NullPointerException if {@code server} is {@code null} + */ + public static WebsocketServerTransport create(final HttpServer server) { + Objects.requireNonNull(server, "server must not be null"); + return new WebsocketServerTransport(server); + } + + /** + * Add a header and value(s) to set on the response of WebSocket handshakes. + * + * @param name the header name + * @param values the header value(s) + * @return the same instance for method chaining + * @since 1.0.1 + */ + public WebsocketServerTransport header(String name, String... values) { + if (values != null) { + Arrays.stream(values).forEach(value -> headers.add(name, value)); + } + return this; + } + + @Override + public Mono start(ConnectionAcceptor acceptor) { + Objects.requireNonNull(acceptor, "acceptor must not be null"); + return server + .handle( + (request, response) -> { + response.headers(headers); + return response.sendWebsocket( + (in, out) -> + acceptor + .apply(new WebsocketDuplexConnection("server", (Connection) in)) + .then(out.neverComplete()), + specBuilder.build()); + }) + .bind() + .map(CloseableChannel::new); + } +} diff --git a/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/package-info.java b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/package-info.java new file mode 100644 index 000000000..031844d06 --- /dev/null +++ b/rsocket-transport-netty/src/main/java/io/rsocket/transport/netty/server/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** The Netty-based RSocket server transport implementations. */ +@NonNullApi +package io.rsocket.transport.netty.server; + +import reactor.util.annotation.NonNullApi; diff --git a/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json new file mode 100644 index 000000000..3a2baa440 --- /dev/null +++ b/rsocket-transport-netty/src/main/resources/META-INF/native-image/io.rsocket/rsocket-transport-netty/reflect-config.json @@ -0,0 +1,16 @@ +[ + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.RSocketLengthCodec" + }, + "name": "io.rsocket.transport.netty.RSocketLengthCodec", + "queryAllPublicMethods": true + }, + { + "condition": { + "typeReachable": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler" + }, + "name": "io.rsocket.transport.netty.server.BaseWebsocketServerTransport$PongHandler", + "queryAllPublicMethods": true + } +] \ No newline at end of file diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java new file mode 100644 index 000000000..23041ec65 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/FragmentTest.java @@ -0,0 +1,184 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.integration; + +import static org.assertj.core.api.Assertions.assertThat; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.RSocketProxy; +import java.util.concurrent.ThreadLocalRandom; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +public class FragmentTest { + private RSocket handler; + private CloseableChannel server; + private String message = null; + private String metaData = null; + private String responseMessage = null; + + private static Stream cases() { + return Stream.of(Arguments.of(0, 64), Arguments.of(64, 0), Arguments.of(64, 64)); + } + + public void startup(int frameSize) { + int randomPort = ThreadLocalRandom.current().nextInt(10_000, 20_000); + StringBuilder message = new StringBuilder(); + StringBuilder responseMessage = new StringBuilder(); + StringBuilder metaData = new StringBuilder(); + for (int i = 0; i < 100; i++) { + message.append("REQUEST "); + responseMessage.append("RESPONSE "); + metaData.append("METADATA "); + } + this.message = message.toString(); + this.responseMessage = responseMessage.toString(); + this.metaData = metaData.toString(); + + TcpServerTransport serverTransport = TcpServerTransport.create("localhost", randomPort); + server = + RSocketServer.create((setup, sendingSocket) -> Mono.just(new RSocketProxy(handler))) + .fragment(frameSize) + .bind(serverTransport) + .block(); + } + + private RSocket buildClient(int frameSize) { + return RSocketConnector.create() + .fragment(frameSize) + .connect(TcpClientTransport.create(server.address())) + .block(); + } + + @AfterEach + public void cleanup() { + server.dispose(); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentNoMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + System.out.println( + "-------------------------------------------------testFragmentNoMetaData-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentRequestMetaDataOnly(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + System.out.println( + "-------------------------------------------------testFragmentRequestMetaDataOnly-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message, metaData)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } + + @ParameterizedTest + @MethodSource("cases") + void testFragmentBothMetaData(int clientFrameSize, int serverFrameSize) { + startup(serverFrameSize); + Payload responsePayload = DefaultPayload.create(responseMessage); + System.out.println( + "-------------------------------------------------testFragmentBothMetaData-------------------------------------------------"); + handler = + new RSocket() { + @Override + public Flux requestStream(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Flux.just(DefaultPayload.create(responseMessage, metaData)); + } + + @Override + public Mono requestResponse(Payload payload) { + String request = payload.getDataUtf8(); + String metaData = payload.getMetadataUtf8(); + System.out.println("request message: " + request); + System.out.println("request metadata: " + metaData); + + return Mono.just(DefaultPayload.create(responseMessage, metaData)); + } + }; + + RSocket client = buildClient(clientFrameSize); + + System.out.println("original message: " + message); + System.out.println("original metadata: " + metaData); + Payload payload = client.requestStream(DefaultPayload.create(message, metaData)).blockLast(); + System.out.println("response message: " + payload.getDataUtf8()); + System.out.println("response metadata: " + payload.getMetadataUtf8()); + + assertThat(responseMessage).isEqualTo(payload.getDataUtf8()); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java new file mode 100644 index 000000000..f05713215 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/integration/KeepaliveTest.java @@ -0,0 +1,190 @@ +package io.rsocket.integration; + +import io.rsocket.Payload; +import io.rsocket.RSocket; +import io.rsocket.core.RSocketClient; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.util.DefaultPayload; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; +import reactor.util.retry.Retry; +import reactor.util.retry.RetryBackoffSpec; + +/** + * Test case that reproduces the following GitHub Issue + */ +public class KeepaliveTest { + + private static final Logger LOG = LoggerFactory.getLogger(KeepaliveTest.class); + private static final int PORT = 23200; + + private CloseableChannel server; + + @BeforeEach + void setUp() { + server = createServer().block(); + } + + @AfterEach + void tearDown() { + server.dispose(); + server.onClose().block(); + } + + @Test + void keepAliveTest() { + RSocketClient rsocketClient = createClient(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketClient + .requestResponse(Mono.just(DefaultPayload.create(""))) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete())) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + @Test + void keepAliveTestLazy() { + Mono rsocketMono = createClientLazy(); + + int expectedCount = 4; + AtomicBoolean sleepOnce = new AtomicBoolean(true); + StepVerifier.create( + Flux.range(0, expectedCount) + .delayElements(Duration.ofMillis(2000)) + .concatMap( + i -> + rsocketMono.flatMap( + rsocket -> + rsocket + .requestResponse(DefaultPayload.create("")) + .doOnNext( + __ -> { + if (sleepOnce.getAndSet(false)) { + try { + LOG.info("Sleeping..."); + Thread.sleep(1_000); + LOG.info("Waking up."); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + }) + .log("id " + i) + .onErrorComplete()))) + .expectSubscription() + .expectNextCount(expectedCount) + .verifyComplete(); + } + + private static Mono createServer() { + LOG.info("Starting server at port {}", PORT); + + TcpServer tcpServer = TcpServer.create().host("localhost").port(PORT); + + return RSocketServer.create( + (setupPayload, rSocket) -> { + rSocket + .onClose() + .doFirst(() -> LOG.info("Connected on server side.")) + .doOnTerminate(() -> LOG.info("Connection closed on server side.")) + .subscribe(); + + return Mono.just(new MyServerRsocket()); + }) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create(tcpServer)) + .doOnNext(closeableChannel -> LOG.info("RSocket server started.")); + } + + private static RSocketClient createClient() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + Mono rsocketMono = + RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + + RSocketClient client = RSocketClient.from(rsocketMono); + + client + .source() + .doOnNext(r -> LOG.info("Got RSocket")) + .flatMap(RSocket::onClose) + .doOnError(err -> LOG.error("Error during onClose.", err)) + .retryWhen(reconnectSpec.apply("client-close")) + .doFirst(() -> LOG.info("Connected on client side.")) + .doOnTerminate(() -> LOG.info("Connection closed on client side.")) + .repeat() + .subscribe(); + + return client; + } + + private static Mono createClientLazy() { + LOG.info("Connecting...."); + + Function reconnectSpec = + reason -> + Retry.backoff(Long.MAX_VALUE, Duration.ofSeconds(10L)) + .doBeforeRetry(retrySignal -> LOG.info("Reconnecting. Reason: {}", reason)); + + return RSocketConnector.create() + .fragment(16384) + .reconnect(reconnectSpec.apply("connector-close")) + .keepAlive(Duration.ofMillis(100L), Duration.ofMillis(900L)) + .connect(TcpClientTransport.create(TcpClient.create().host("localhost").port(PORT))); + } + + public static class MyServerRsocket implements RSocket { + + @Override + public Mono requestResponse(Payload payload) { + return Mono.just("Pong").map(DefaultPayload::create); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java new file mode 100644 index 000000000..b9c0d4f60 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/RSocketFactoryNettyTransportFragmentationTest.java @@ -0,0 +1,80 @@ +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.time.Duration; +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +class RSocketFactoryNettyTransportFragmentationTest { + + static Stream> arguments() { + return Stream.of(TcpServerTransport.create(0), WebsocketServerTransport.create(0)); + } + + @ParameterizedTest + @MethodSource("arguments") + void serverSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + Mono server = + RSocketServer.create(mockAcceptor()) + .fragment(100) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void serverSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + Mono server = + RSocketServer.create(mockAcceptor()) + .bind(serverTransport) + .doOnNext(CloseableChannel::dispose); + StepVerifier.create(server).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void clientSucceedsWithEnabledFragmentationOnSufficientMtu( + ServerTransport serverTransport) { + CloseableChannel server = + RSocketServer.create(mockAcceptor()).fragment(100).bind(serverTransport).block(); + + Mono rSocket = + RSocketConnector.create() + .fragment(100) + .connect(TcpClientTransport.create(server.address())) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + @ParameterizedTest + @MethodSource("arguments") + void clientSucceedsWithDisabledFragmentation(ServerTransport serverTransport) { + CloseableChannel server = RSocketServer.create(mockAcceptor()).bind(serverTransport).block(); + + Mono rSocket = + RSocketConnector.connectWith(TcpClientTransport.create(server.address())) + .doFinally(s -> server.dispose()); + StepVerifier.create(rSocket).expectNextCount(1).expectComplete().verify(Duration.ofSeconds(5)); + } + + private SocketAcceptor mockAcceptor() { + SocketAcceptor mock = Mockito.mock(SocketAcceptor.class); + Mockito.when(mock.accept(Mockito.any(), Mockito.any())) + .thenReturn(Mono.just(Mockito.mock(RSocket.class))); + return mock; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java new file mode 100644 index 000000000..76c352768 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/SetupRejectionTest.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.transport.netty; + +import io.rsocket.ConnectionSetupPayload; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.exceptions.RejectedSetupException; +import io.rsocket.transport.ClientTransport; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.CloseableChannel; +import io.rsocket.transport.netty.server.TcpServerTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.DefaultPayload; +import java.net.InetSocketAddress; +import java.time.Duration; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Stream; +import org.junit.jupiter.params.provider.Arguments; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.test.StepVerifier; + +public class SetupRejectionTest { + + /* + TODO Fix this test + @DisplayName( + "Rejecting setup by server causes requester RSocket disposal and RejectedSetupException") + @ParameterizedTest + @MethodSource(value = "transports")*/ + void rejectSetupTcp( + Function> serverTransport, + Function clientTransport) { + + String errorMessage = "error"; + RejectingAcceptor acceptor = new RejectingAcceptor(errorMessage); + Mono serverRequester = acceptor.requesterRSocket(); + + CloseableChannel channel = + RSocketServer.create(acceptor) + .bind(serverTransport.apply(new InetSocketAddress("localhost", 0))) + .block(Duration.ofSeconds(5)); + + ErrorConsumer errorConsumer = new ErrorConsumer(); + + RSocket clientRequester = + RSocketConnector.connectWith(clientTransport.apply(channel.address())) + .doOnError(errorConsumer) + .block(Duration.ofSeconds(5)); + + StepVerifier.create(errorConsumer.errors().next()) + .expectNextMatches( + err -> err instanceof RejectedSetupException && errorMessage.equals(err.getMessage())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + StepVerifier.create(clientRequester.onClose()).expectComplete().verify(Duration.ofSeconds(5)); + + StepVerifier.create(serverRequester.flatMap(socket -> socket.onClose())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + StepVerifier.create(clientRequester.requestResponse(DefaultPayload.create("test"))) + .expectErrorMatches( + err -> err instanceof RejectedSetupException && errorMessage.equals(err.getMessage())) + .verify(Duration.ofSeconds(5)); + + channel.dispose(); + } + + static Stream transports() { + Function> tcpServer = + TcpServerTransport::create; + Function> wsServer = + WebsocketServerTransport::create; + Function tcpClient = TcpClientTransport::create; + Function wsClient = WebsocketClientTransport::create; + + return Stream.of(Arguments.of(tcpServer, tcpClient), Arguments.of(wsServer, wsClient)); + } + + static class ErrorConsumer implements Consumer { + private final Sinks.Many errors = Sinks.many().multicast().onBackpressureBuffer(); + + @Override + public void accept(Throwable t) { + errors.tryEmitNext(t); + } + + Flux errors() { + return errors.asFlux(); + } + } + + private static class RejectingAcceptor implements SocketAcceptor { + private final String msg; + private final Sinks.Many requesters = Sinks.many().multicast().onBackpressureBuffer(); + + public RejectingAcceptor(String msg) { + this.msg = msg; + } + + @Override + public Mono accept(ConnectionSetupPayload setup, RSocket sendingSocket) { + requesters.tryEmitNext(sendingSocket); + return Mono.error(new RuntimeException(msg)); + } + + public Mono requesterRSocket() { + return requesters.asFlux().next(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java new file mode 100644 index 000000000..b17da654f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpFragmentationTransportTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(2); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java new file mode 100644 index 000000000..88c64648c --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPing.java @@ -0,0 +1,97 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.Resume; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.test.PerfTest; +import io.rsocket.test.PingClient; +import io.rsocket.transport.netty.client.TcpClientTransport; +import java.time.Duration; +import org.HdrHistogram.Recorder; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; + +@PerfTest +public final class TcpPing { + private static final int INTERACTIONS_COUNT = 1_000_000_000; + private static final int port = Integer.valueOf(System.getProperty("RSOCKET_TEST_PORT", "7878")); + + @BeforeEach + void setUp() { + System.out.println("Starting ping-pong test (TCP transport)"); + System.out.println("port: " + port); + } + + @Test + void requestResponseTest() { + PingClient pingClient = newPingClient(); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); + + pingClient + .requestResponsePingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) + .blockLast(); + } + + @Test + void requestStreamTest() { + PingClient pingClient = newPingClient(); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); + + pingClient + .requestStreamPingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) + .blockLast(); + } + + @Test + void requestStreamResumableTest() { + PingClient pingClient = newResumablePingClient(); + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); + + pingClient + .requestStreamPingPong(INTERACTIONS_COUNT, recorder) + .doOnTerminate(() -> System.out.println("Sent " + INTERACTIONS_COUNT + " messages.")) + .blockLast(); + } + + private static PingClient newPingClient() { + return newPingClient(false); + } + + private static PingClient newResumablePingClient() { + return newPingClient(true); + } + + private static PingClient newPingClient(boolean isResumable) { + RSocketConnector connector = RSocketConnector.create(); + if (isResumable) { + connector.resume(new Resume()); + } + Mono rSocket = + connector + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .keepAlive(Duration.ofMinutes(1), Duration.ofMinutes(30)) + .connect(TcpClientTransport.create(port)); + + return new PingClient(rSocket); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java new file mode 100644 index 000000000..338868470 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpPongServer.java @@ -0,0 +1,46 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.rsocket.core.RSocketServer; +import io.rsocket.core.Resume; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.test.PingHandler; +import io.rsocket.transport.netty.server.TcpServerTransport; + +public final class TcpPongServer { + private static final boolean isResume = + Boolean.valueOf(System.getProperty("RSOCKET_TEST_RESUME", "false")); + private static final int port = Integer.valueOf(System.getProperty("RSOCKET_TEST_PORT", "7878")); + + public static void main(String... args) { + System.out.println("Starting TCP ping-pong server"); + System.out.println("port: " + port); + System.out.println("resume enabled: " + isResume); + + RSocketServer server = RSocketServer.create(new PingHandler()); + if (isResume) { + server.resume(new Resume()); + } + server + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(TcpServerTransport.create("localhost", port)) + .block() + .onClose() + .block(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java new file mode 100644 index 000000000..7be1c1c54 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..39b3cec67 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpResumableWithFragmentationTransportTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java new file mode 100644 index 000000000..ee49b83cd --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpSecureTransportTest.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.core.Exceptions; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +public class TcpSecureTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE)))), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + TcpServer server = + TcpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return TcpServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(10); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java new file mode 100644 index 000000000..428681f3e --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/TcpTransportTest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.TcpClientTransport; +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.tcp.TcpClient; +import reactor.netty.tcp.TcpServer; + +final class TcpTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + TcpClientTransport.create( + TcpClient.create() + .remoteAddress(server::address) + .option(ChannelOption.ALLOCATOR, allocator)), + (address, allocator) -> { + return TcpServerTransport.create( + TcpServer.create() + .bindAddress(() -> address) + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(2); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java new file mode 100644 index 000000000..2deb4a4a8 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClient.java @@ -0,0 +1,128 @@ +package io.rsocket.transport.netty; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.websocketx.*; +import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.URI; + +/** + * This is an example of a WebSocket client. + * + *

In order to run this example you need a compatible WebSocket server. Therefore you can either + * start the WebSocket server from the examples or connect to an existing WebSocket server such as + * ws://echo.websocket.org. + * + *

The client will attempt to connect to the URI passed to it as the first argument. You don't + * have to specify any arguments if you want to connect to the example WebSocket server, as this is + * the default. + */ +public final class WebSocketClient { + + static final String URL = System.getProperty("url", "ws://127.0.0.1:7878/websocket"); + + public static void main(String[] args) throws Exception { + URI uri = new URI(URL); + String scheme = uri.getScheme() == null ? "ws" : uri.getScheme(); + final String host = uri.getHost() == null ? "127.0.0.1" : uri.getHost(); + final int port; + if (uri.getPort() == -1) { + if ("ws".equalsIgnoreCase(scheme)) { + port = 80; + } else if ("wss".equalsIgnoreCase(scheme)) { + port = 443; + } else { + port = -1; + } + } else { + port = uri.getPort(); + } + + if (!"ws".equalsIgnoreCase(scheme) && !"wss".equalsIgnoreCase(scheme)) { + System.err.println("Only WS(S) is supported."); + return; + } + + final boolean ssl = "wss".equalsIgnoreCase(scheme); + final SslContext sslCtx; + if (ssl) { + sslCtx = + SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + } else { + sslCtx = null; + } + + EventLoopGroup group = new NioEventLoopGroup(); + try { + // Connect with V13 (RFC 6455 aka HyBi-17). You can change it to V08 or V00. + // If you change it to V00, ping is not supported and remember to change + // HttpResponseDecoder to WebSocketHttpResponseDecoder in the pipeline. + final WebSocketClientHandler handler = + new WebSocketClientHandler( + WebSocketClientHandshakerFactory.newHandshaker( + uri, WebSocketVersion.V13, null, true, new DefaultHttpHeaders())); + + Bootstrap b = new Bootstrap(); + b.group(group) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline p = ch.pipeline(); + if (sslCtx != null) { + p.addLast(sslCtx.newHandler(ch.alloc(), host, port)); + } + p.addLast( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + WebSocketClientCompressionHandler.INSTANCE, + handler); + } + }); + + Channel ch = b.connect(uri.getHost(), port).sync().channel(); + handler.handshakeFuture().sync(); + + BufferedReader console = new BufferedReader(new InputStreamReader(System.in)); + while (true) { + String msg = console.readLine(); + if (msg == null) { + break; + } else if ("bye".equals(msg.toLowerCase())) { + ch.writeAndFlush(new CloseWebSocketFrame()); + ch.closeFuture().sync(); + break; + } else if ("ping".equals(msg.toLowerCase())) { + WebSocketFrame frame = + new PingWebSocketFrame(Unpooled.wrappedBuffer(new byte[] {8, 1, 8, 1})); + ch.writeAndFlush(frame); + } else if ("pong".equals(msg.toLowerCase())) { + WebSocketFrame frame = + new PongWebSocketFrame(Unpooled.wrappedBuffer(new byte[] {8, 1, 8, 1})); + ch.writeAndFlush(frame); + } else { + WebSocketFrame frame = new TextWebSocketFrame(msg); + ch.writeAndFlush(frame); + } + } + } finally { + group.shutdownGracefully(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java new file mode 100644 index 000000000..092cad2c7 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketClientHandler.java @@ -0,0 +1,90 @@ +package io.rsocket.transport.netty; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketClientHandshaker; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; +import io.netty.util.CharsetUtil; + +public class WebSocketClientHandler extends SimpleChannelInboundHandler { + + private final WebSocketClientHandshaker handshaker; + private ChannelPromise handshakeFuture; + + public WebSocketClientHandler(WebSocketClientHandshaker handshaker) { + this.handshaker = handshaker; + } + + public ChannelFuture handshakeFuture() { + return handshakeFuture; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handshakeFuture = ctx.newPromise(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + handshaker.handshake(ctx.channel()); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + System.out.println("WebSocket Client disconnected!"); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + Channel ch = ctx.channel(); + if (!handshaker.isHandshakeComplete()) { + try { + handshaker.finishHandshake(ch, (FullHttpResponse) msg); + System.out.println("WebSocket Client connected!"); + handshakeFuture.setSuccess(); + } catch (WebSocketHandshakeException e) { + System.out.println("WebSocket Client failed to connect"); + handshakeFuture.setFailure(e); + } + return; + } + + if (msg instanceof FullHttpResponse) { + FullHttpResponse response = (FullHttpResponse) msg; + throw new IllegalStateException( + "Unexpected FullHttpResponse (getStatus=" + + response.status() + + ", content=" + + response.content().toString(CharsetUtil.UTF_8) + + ')'); + } + + WebSocketFrame frame = (WebSocketFrame) msg; + if (frame instanceof TextWebSocketFrame) { + TextWebSocketFrame textFrame = (TextWebSocketFrame) frame; + System.out.println("WebSocket Client received message: " + textFrame.text()); + } else if (frame instanceof PongWebSocketFrame) { + System.out.println("WebSocket Client received pong"); + } else if (frame instanceof CloseWebSocketFrame) { + System.out.println("WebSocket Client received closing"); + ch.close(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + cause.printStackTrace(); + if (!handshakeFuture.isDone()) { + handshakeFuture.setFailure(cause); + } + ctx.close(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java new file mode 100644 index 000000000..c418dea0f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebSocketTransportIntegrationTest.java @@ -0,0 +1,49 @@ +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketRouteTransport; +import io.rsocket.util.DefaultPayload; +import io.rsocket.util.EmptyPayload; +import java.net.URI; +import java.time.Duration; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +public class WebSocketTransportIntegrationTest { + + @Test + public void sendStreamOfDataWithExternalHttpServerTest() { + ServerTransport.ConnectionAcceptor acceptor = + RSocketServer.create( + SocketAcceptor.forRequestStream( + payload -> + Flux.range(0, 10).map(i -> DefaultPayload.create(String.valueOf(i))))) + .asConnectionAcceptor(); + + DisposableServer server = + HttpServer.create() + .host("localhost") + .route(router -> router.ws("/test", WebsocketRouteTransport.newHandler(acceptor))) + .bindNow(); + + RSocket rsocket = + RSocketConnector.connectWith( + WebsocketClientTransport.create( + URI.create("ws://" + server.host() + ":" + server.port() + "/test"))) + .block(); + + StepVerifier.create(rsocket.requestStream(EmptyPayload.INSTANCE)) + .expectSubscription() + .expectNextCount(10) + .expectComplete() + .verify(Duration.ofMillis(1000)); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java new file mode 100644 index 000000000..a784a43c0 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPing.java @@ -0,0 +1,47 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.rsocket.RSocket; +import io.rsocket.core.RSocketConnector; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.test.PingClient; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import java.time.Duration; +import org.HdrHistogram.Recorder; +import reactor.core.publisher.Mono; + +public final class WebsocketPing { + + public static void main(String... args) { + Mono client = + RSocketConnector.create() + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .connect(WebsocketClientTransport.create(7878)); + + PingClient pingClient = new PingClient(client); + + Recorder recorder = pingClient.startTracker(Duration.ofSeconds(1)); + + int count = 1_000_000_000; + + pingClient + .requestResponsePingPong(count, recorder) + .doOnTerminate(() -> System.out.println("Sent " + count + " messages.")) + .blockLast(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java new file mode 100644 index 000000000..ff0fa75b4 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPingPongIntegrationTest.java @@ -0,0 +1,168 @@ +/* + * Copyright 2015-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.rsocket.transport.netty; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.util.ReferenceCountUtil; +import io.rsocket.Closeable; +import io.rsocket.RSocket; +import io.rsocket.SocketAcceptor; +import io.rsocket.core.RSocketConnector; +import io.rsocket.core.RSocketServer; +import io.rsocket.transport.ServerTransport; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketRouteTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import io.rsocket.util.DefaultPayload; +import java.nio.charset.StandardCharsets; +import java.time.Duration; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import reactor.core.Scannable; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +public class WebsocketPingPongIntegrationTest { + private static final String host = "localhost"; + private static final int port = 8088; + + private Closeable server; + + @AfterEach + void tearDown() { + server.dispose(); + } + + @ParameterizedTest + @MethodSource("provideServerTransport") + void webSocketPingPong(ServerTransport serverTransport) { + server = + RSocketServer.create(SocketAcceptor.forRequestResponse(Mono::just)) + .bind(serverTransport) + .block(); + + String expectedData = "data"; + String expectedPing = "ping"; + + PingSender pingSender = new PingSender(); + + HttpClient httpClient = + HttpClient.create() + .tcpConfiguration( + tcpClient -> + tcpClient + .doOnConnected(b -> b.addHandlerLast(pingSender)) + .host(host) + .port(port)); + + RSocket rSocket = + RSocketConnector.connectWith(WebsocketClientTransport.create(httpClient, "/")).block(); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPing(expectedPing)) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + pingSender + .receivePong() + .as(StepVerifier::create) + .expectNextMatches(expectedPing::equals) + .expectComplete() + .verify(Duration.ofSeconds(5)); + + rSocket + .requestResponse(DefaultPayload.create(expectedData)) + .delaySubscription(pingSender.sendPong()) + .as(StepVerifier::create) + .expectNextMatches(p -> expectedData.equals(p.getDataUtf8())) + .expectComplete() + .verify(Duration.ofSeconds(5)); + } + + private static Stream provideServerTransport() { + return Stream.of( + Arguments.of(WebsocketServerTransport.create(host, port)), + Arguments.of( + new WebsocketRouteTransport( + HttpServer.create().host(host).port(port), routes -> {}, "/"))); + } + + private static class PingSender extends ChannelInboundHandlerAdapter { + private final Sinks.One channel = Sinks.one(); + private final Sinks.One pong = Sinks.one(); + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof PongWebSocketFrame) { + pong.tryEmitValue(((PongWebSocketFrame) msg).content().toString(StandardCharsets.UTF_8)); + ReferenceCountUtil.safeRelease(msg); + ctx.read(); + } else { + super.channelRead(ctx, msg); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (!(channel.scan(Scannable.Attr.TERMINATED)) && ch.isWritable()) { + channel.tryEmitValue(ctx.channel()); + } + super.channelWritabilityChanged(ctx); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + Channel ch = ctx.channel(); + if (ch.isWritable()) { + channel.tryEmitValue(ch); + } + super.handlerAdded(ctx); + } + + public Mono sendPing(String data) { + return send( + new PingWebSocketFrame(Unpooled.wrappedBuffer(data.getBytes(StandardCharsets.UTF_8)))); + } + + public Mono sendPong() { + return send(new PongWebSocketFrame()); + } + + public Mono receivePong() { + return pong.asMono(); + } + + private Mono send(WebSocketFrame webSocketFrame) { + return channel.asMono().doOnNext(ch -> ch.writeAndFlush(webSocketFrame)).then(); + } + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java new file mode 100644 index 000000000..84dc816be --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketPongServer.java @@ -0,0 +1,34 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.rsocket.core.RSocketServer; +import io.rsocket.frame.decoder.PayloadDecoder; +import io.rsocket.test.PingHandler; +import io.rsocket.transport.netty.server.WebsocketServerTransport; + +public final class WebsocketPongServer { + + public static void main(String... args) { + RSocketServer.create(new PingHandler()) + .payloadDecoder(PayloadDecoder.ZERO_COPY) + .bind(WebsocketServerTransport.create(7878)) + .block() + .onClose() + .block(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java new file mode 100644 index 000000000..043f6bc64 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + false, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java new file mode 100644 index 000000000..b1ca65fcc --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketResumableWithFragmentationTransportTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketResumableWithFragmentationTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }, + true, + true); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java new file mode 100644 index 000000000..81f7ffb95 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketSecureTransportTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.core.Exceptions; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketSecureTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> new InetSocketAddress("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .option(ChannelOption.ALLOCATOR, allocator) + .remoteAddress(server::address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE))), + String.format( + "https://%s:%d/", + server.address().getHostName(), server.address().getPort())), + (address, allocator) -> { + try { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + HttpServer server = + HttpServer.create() + .option(ChannelOption.ALLOCATOR, allocator) + .bindAddress(() -> address) + .secure( + ssl -> + ssl.sslContext( + SslContextBuilder.forServer( + ssc.certificate(), ssc.privateKey()))); + return WebsocketServerTransport.create(server); + } catch (CertificateException e) { + throw Exceptions.propagate(e); + } + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(5); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java new file mode 100644 index 000000000..cdd507456 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/WebsocketTransportTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2015-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty; + +import io.netty.channel.ChannelOption; +import io.rsocket.test.TransportTest; +import io.rsocket.transport.netty.client.WebsocketClientTransport; +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.time.Duration; +import org.junit.jupiter.api.BeforeEach; +import reactor.netty.http.client.HttpClient; +import reactor.netty.http.server.HttpServer; + +final class WebsocketTransportTest implements TransportTest { + private TransportPair transportPair; + + @BeforeEach + void createTestPair() { + transportPair = + new TransportPair<>( + () -> InetSocketAddress.createUnresolved("localhost", 0), + (address, server, allocator) -> + WebsocketClientTransport.create( + HttpClient.create() + .host(server.address().getHostName()) + .port(server.address().getPort()) + .option(ChannelOption.ALLOCATOR, allocator), + ""), + (address, allocator) -> { + return WebsocketServerTransport.create( + HttpServer.create() + .host(address.getHostName()) + .port(address.getPort()) + .option(ChannelOption.ALLOCATOR, allocator)); + }); + } + + @Override + public Duration getTimeout() { + return Duration.ofMinutes(3); + } + + @Override + public TransportPair getTransportPair() { + return transportPair; + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java new file mode 100644 index 000000000..ac4c6044b --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/TcpClientTransportTest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.rsocket.transport.netty.server.TcpServerTransport; +import java.net.InetSocketAddress; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpClient; +import reactor.test.StepVerifier; + +final class TcpClientTransportTest { + + @DisplayName("connects to server") + @Test + void connect() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + TcpServerTransport serverTransport = TcpServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .flatMap(context -> TcpClientTransport.create(context.address()).connect()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("create generates error if server not started") + @Test + void connectNoServer() { + TcpClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + } + + @DisplayName("creates client with BindAddress") + @Test + void createBindAddress() { + assertThat(TcpClientTransport.create("test-bind-address", 8000)).isNotNull(); + } + + @DisplayName("creates client with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + TcpClientTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull(); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpClientTransport.create((String) null, 8000)) + .withMessage("bindAddress must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpClientTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("create throws NullPointerException with null client") + @Test + void createNullTcpClient() { + assertThatNullPointerException() + .isThrownBy(() -> TcpClientTransport.create((TcpClient) null)) + .withMessage("client must not be null"); + } + + @DisplayName("creates client with port") + @Test + void createPort() { + assertThat(TcpClientTransport.create(8000)).isNotNull(); + } + + @DisplayName("creates client with TcpClient") + @Test + void createTcpClient() { + assertThat(TcpClientTransport.create(TcpClient.create())).isNotNull(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java new file mode 100644 index 000000000..2a3670251 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/client/WebsocketClientTransportTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.client; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import io.rsocket.transport.netty.server.WebsocketServerTransport; +import java.net.InetSocketAddress; +import java.net.URI; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Mono; +import reactor.netty.http.client.HttpClient; +import reactor.test.StepVerifier; + +@ExtendWith(MockitoExtension.class) +final class WebsocketClientTransportTest { + + @DisplayName("connects to server") + @Test + void connect() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .flatMap(context -> WebsocketClientTransport.create(context.address()).connect()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("create generates error if server not started") + @Test + void connectNoServer() { + WebsocketClientTransport.create(8000).connect().as(StepVerifier::create).verifyError(); + } + + @DisplayName("creates client with BindAddress") + @Test + void createBindAddress() { + assertThat(WebsocketClientTransport.create("test-bind-address", 8000)) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with HttpClient") + @Test + void createHttpClient() { + assertThat(WebsocketClientTransport.create(HttpClient.create(), "/")) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with HttpClient and path without root") + @Test + void createHttpClientWithPathWithoutRoot() { + assertThat(WebsocketClientTransport.create(HttpClient.create(), "test")) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/test"); + } + + @DisplayName("creates client with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + WebsocketClientTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create(null, 8000)) + .withMessage("host"); + } + + @DisplayName("create throws NullPointerException with null client") + @Test + void createNullHttpClient() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create(null, "/test-path")) + .withMessage("HttpClient must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("create throws NullPointerException with null path") + @Test + void createNullPath() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create(HttpClient.create(), null)) + .withMessage("path must not be null"); + } + + @DisplayName("create throws NullPointerException with null URI") + @Test + void createNullUri() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketClientTransport.create((URI) null)) + .withMessage("uri must not be null"); + } + + @DisplayName("creates client with port") + @Test + void createPort() { + assertThat(WebsocketClientTransport.create(8000)).isNotNull(); + } + + @DisplayName("creates client with URI") + @Test + void createUri() { + assertThat(WebsocketClientTransport.create(URI.create("ws://test-host"))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/"); + } + + @DisplayName("creates client with URI path") + @Test + void createUriPath() { + assertThat(WebsocketClientTransport.create(URI.create("ws://test-host/test"))) + .isNotNull() + .hasFieldOrPropertyWithValue("path", "/test"); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java new file mode 100644 index 000000000..bd53a9b3f --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/CloseableChannelTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.DisposableChannel; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; + +final class CloseableChannelTest { + + private final Mono channel = + TcpServer.create().handle((in, out) -> Mono.empty()).bind(); + + @DisplayName("returns the address of the context") + @Test + void address() { + channel + .map(CloseableChannel::new) + .map(CloseableChannel::address) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("creates instance") + @Test + void constructor() { + channel.map(CloseableChannel::new).as(StepVerifier::create).expectNextCount(1).verifyComplete(); + } + + @DisplayName("constructor throws NullPointerException with null context") + @Test + void constructorNullContext() { + assertThatNullPointerException() + .isThrownBy(() -> new CloseableChannel(null)) + .withMessage("channel must not be null"); + } + + @DisplayName("disposes context") + @Test + void dispose() { + channel + .map(CloseableChannel::new) + .delayUntil( + closeable -> { + closeable.dispose(); + return closeable.onClose().log(); + }) + .as(StepVerifier::create) + .assertNext(closeable -> assertThat(closeable.isDisposed()).isTrue()) + .verifyComplete(); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java new file mode 100644 index 000000000..0e14d8f1d --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/TcpServerTransportTest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import java.net.InetSocketAddress; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; + +final class TcpServerTransportTest { + + @DisplayName("creates server with BindAddress") + @Test + void createBindAddress() { + assertThat(TcpServerTransport.create("test-bind-address", 8000)).isNotNull(); + } + + @DisplayName("creates server with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + TcpServerTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull(); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create((String) null, 8000)) + .withMessage("bindAddress must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("create throws NullPointerException with null server") + @Test + void createNullTcpClient() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create((TcpServer) null)) + .withMessage("server must not be null"); + } + + @DisplayName("creates server with port") + @Test + void createPort() { + assertThat(TcpServerTransport.create("localhost", 8000)).isNotNull(); + } + + @DisplayName("creates client with TcpServer") + @Test + void createTcpClient() { + assertThat(TcpServerTransport.create(TcpServer.create())).isNotNull(); + } + + @DisplayName("starts server") + @Test + void start() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + TcpServerTransport serverTransport = TcpServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("start throws NullPointerException with null acceptor") + @Test + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy(() -> TcpServerTransport.create("localhost", 8000).start(null)) + .withMessage("acceptor must not be null"); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java new file mode 100644 index 000000000..2670b4a4b --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketRouteTransportTest.java @@ -0,0 +1,82 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static org.assertj.core.api.Assertions.assertThatNullPointerException; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.netty.http.server.HttpServer; +import reactor.test.StepVerifier; + +final class WebsocketRouteTransportTest { + + @DisplayName("creates server") + @Test + void constructor() { + new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path"); + } + + @DisplayName("constructor throw NullPointer with null path") + @Test + void constructorNullPath() { + assertThatNullPointerException() + .isThrownBy(() -> new WebsocketRouteTransport(HttpServer.create(), routes -> {}, null)) + .withMessage("path must not be null"); + } + + @DisplayName("constructor throw NullPointer with null routesBuilder") + @Test + void constructorNullRoutesBuilder() { + assertThatNullPointerException() + .isThrownBy(() -> new WebsocketRouteTransport(HttpServer.create(), null, "/test-path")) + .withMessage("routesBuilder must not be null"); + } + + @DisplayName("constructor throw NullPointer with null server") + @Test + void constructorNullServer() { + assertThatNullPointerException() + .isThrownBy(() -> new WebsocketRouteTransport(null, routes -> {}, "/test-path")) + .withMessage("server must not be null"); + } + + @DisplayName("starts server") + @Test + void start() { + WebsocketRouteTransport serverTransport = + new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path"); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("start throw NullPointerException with null acceptor") + @Test + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy( + () -> + new WebsocketRouteTransport(HttpServer.create(), routes -> {}, "/test-path") + .start(null)) + .withMessage("acceptor must not be null"); + } +} diff --git a/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java new file mode 100644 index 000000000..540076704 --- /dev/null +++ b/rsocket-transport-netty/src/test/java/io/rsocket/transport/netty/server/WebsocketServerTransportTest.java @@ -0,0 +1,137 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.transport.netty.server; + +import static io.rsocket.frame.FrameLengthCodec.FRAME_LENGTH_MASK; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatNullPointerException; +import static org.mockito.ArgumentMatchers.any; + +import java.net.InetSocketAddress; +import java.util.function.BiFunction; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import reactor.core.publisher.Mono; +import reactor.netty.http.server.HttpServer; +import reactor.netty.http.server.HttpServerRequest; +import reactor.netty.http.server.HttpServerResponse; +import reactor.netty.http.server.WebsocketServerSpec; +import reactor.test.StepVerifier; + +final class WebsocketServerTransportTest { + + @Test + public void testThatSetupWithUnSpecifiedFrameSizeShouldSetMaxFrameSize() { + ArgumentCaptor httpHandlerCaptor = ArgumentCaptor.forClass(BiFunction.class); + HttpServer server = Mockito.spy(HttpServer.create()); + Mockito.doAnswer(a -> server).when(server).handle(httpHandlerCaptor.capture()); + Mockito.doAnswer(a -> server).when(server).doOnConnection(any()); + Mockito.doAnswer(a -> Mono.empty()).when(server).bind(); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(server); + serverTransport.start(c -> Mono.empty()).subscribe(); + + HttpServerRequest httpServerRequest = Mockito.mock(HttpServerRequest.class); + HttpServerResponse httpServerResponse = Mockito.mock(HttpServerResponse.class); + + httpHandlerCaptor.getValue().apply(httpServerRequest, httpServerResponse); + + ArgumentCaptor handlerCaptor = ArgumentCaptor.forClass(BiFunction.class); + ArgumentCaptor specCaptor = + ArgumentCaptor.forClass(WebsocketServerSpec.class); + + Mockito.verify(httpServerResponse).sendWebsocket(handlerCaptor.capture(), specCaptor.capture()); + + WebsocketServerSpec spec = specCaptor.getValue(); + assertThat(spec.maxFramePayloadLength()).isEqualTo(FRAME_LENGTH_MASK); + } + + @DisplayName("creates server with BindAddress") + @Test + void createBindAddress() { + assertThat(WebsocketServerTransport.create("test-bind-address", 8000)).isNotNull(); + } + + @DisplayName("creates server with HttpClient") + @Test + void createHttpClient() { + assertThat(WebsocketServerTransport.create(HttpServer.create())).isNotNull(); + } + + @DisplayName("creates server with InetSocketAddress") + @Test + void createInetSocketAddress() { + assertThat( + WebsocketServerTransport.create( + InetSocketAddress.createUnresolved("test-bind-address", 8000))) + .isNotNull(); + } + + @DisplayName("create throws NullPointerException with null bindAddress") + @Test + void createNullBindAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create(null, 8000)) + .withMessage("bindAddress must not be null"); + } + + @DisplayName("create throws NullPointerException with null client") + @Test + void createNullHttpClient() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create((HttpServer) null)) + .withMessage("server must not be null"); + } + + @DisplayName("create throws NullPointerException with null address") + @Test + void createNullInetSocketAddress() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create((InetSocketAddress) null)) + .withMessage("address must not be null"); + } + + @DisplayName("creates server with port") + @Test + void createPort() { + assertThat(WebsocketServerTransport.create(8000)).isNotNull(); + } + + @DisplayName("starts server") + @Test + void start() { + InetSocketAddress address = InetSocketAddress.createUnresolved("localhost", 0); + + WebsocketServerTransport serverTransport = WebsocketServerTransport.create(address); + + serverTransport + .start(duplexConnection -> Mono.empty()) + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); + } + + @DisplayName("start throws NullPointerException with null acceptor") + @Test + void startNullAcceptor() { + assertThatNullPointerException() + .isThrownBy(() -> WebsocketServerTransport.create(8000).start(null)) + .withMessage("acceptor must not be null"); + } +} diff --git a/rsocket-transport-netty/src/test/resources/logback-test.xml b/rsocket-transport-netty/src/test/resources/logback-test.xml new file mode 100644 index 000000000..981d6d0b6 --- /dev/null +++ b/rsocket-transport-netty/src/test/resources/logback-test.xml @@ -0,0 +1,42 @@ + + + + + + + + %date{HH:mm:ss.SSS} %-10thread %-42logger %msg%n + + + + + + + + + + + + + + + + + + + + diff --git a/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker b/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker new file mode 100644 index 000000000..ca6ee9cea --- /dev/null +++ b/rsocket-transport-netty/src/test/resources/mockito-extensions/org.mockito.plugins.MockMaker @@ -0,0 +1 @@ +mock-maker-inline \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 5aea80f18..25c3feee5 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1 +1,41 @@ -rootProject.name='reactivesocket' +/* + * Copyright 2015-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'com.gradle.enterprise' version '3.1' +} + +rootProject.name = 'rsocket-java' + +include 'rsocket-core' +include 'rsocket-load-balancer' +include 'rsocket-micrometer' +include 'rsocket-test' +include 'rsocket-transport-local' +include 'rsocket-transport-netty' +include 'rsocket-bom' + +include 'rsocket-examples' +include 'benchmarks' + + + +gradleEnterprise { + buildScan { + termsOfServiceUrl = 'https://gradle.com/terms-of-service' + termsOfServiceAgree = 'yes' + } +} + diff --git a/src/main/java/io/reactivesocket/ConnectionSetupPayload.java b/src/main/java/io/reactivesocket/ConnectionSetupPayload.java deleted file mode 100644 index 37e33e89a..000000000 --- a/src/main/java/io/reactivesocket/ConnectionSetupPayload.java +++ /dev/null @@ -1,164 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import java.nio.ByteBuffer; - -import io.reactivesocket.internal.frame.SetupFrameFlyweight; - -/** - * Exposed to server for determination of RequestHandler based on mime types and SETUP metadata/data - */ -public abstract class ConnectionSetupPayload implements Payload -{ - public static final int NO_FLAGS = 0; - public static final int HONOR_LEASE = SetupFrameFlyweight.FLAGS_WILL_HONOR_LEASE; - public static final int STRICT_INTERPRETATION = SetupFrameFlyweight.FLAGS_STRICT_INTERPRETATION; - - public static ConnectionSetupPayload create(String metadataMimeType, String dataMimeType) { - return new ConnectionSetupPayload() { - public String metadataMimeType() - { - return metadataMimeType; - } - - public String dataMimeType() - { - return dataMimeType; - } - - public ByteBuffer getData() - { - return Frame.NULL_BYTEBUFFER; - } - - public ByteBuffer getMetadata() - { - return Frame.NULL_BYTEBUFFER; - } - }; - } - - public static ConnectionSetupPayload create(String metadataMimeType, String dataMimeType, Payload payload) { - return new ConnectionSetupPayload() { - public String metadataMimeType() - { - return metadataMimeType; - } - - public String dataMimeType() - { - return dataMimeType; - } - - public ByteBuffer getData() - { - return payload.getData(); - } - - public ByteBuffer getMetadata() - { - return payload.getMetadata(); - } - }; - } - - public static ConnectionSetupPayload create(String metadataMimeType, String dataMimeType, int flags) - { - return new ConnectionSetupPayload() { - public String metadataMimeType() - { - return metadataMimeType; - } - - public String dataMimeType() - { - return dataMimeType; - } - - public ByteBuffer getData() - { - return Frame.NULL_BYTEBUFFER; - } - - public ByteBuffer getMetadata() - { - return Frame.NULL_BYTEBUFFER; - } - - @Override - public int getFlags() - { - return flags; - } - }; - } - - public static ConnectionSetupPayload create(final Frame setupFrame) - { - Frame.ensureFrameType(FrameType.SETUP, setupFrame); - return new ConnectionSetupPayload() { - public String metadataMimeType() - { - return Frame.Setup.metadataMimeType(setupFrame); - } - - public String dataMimeType() - { - return Frame.Setup.dataMimeType(setupFrame); - } - - public ByteBuffer getData() - { - return setupFrame.getData(); - } - - public ByteBuffer getMetadata() - { - return setupFrame.getMetadata(); - } - - @Override - public int getFlags() - { - return Frame.Setup.getFlags(setupFrame); - } - }; - } - - public abstract String metadataMimeType(); - - public abstract String dataMimeType(); - - public abstract ByteBuffer getData(); - - public abstract ByteBuffer getMetadata(); - - public int getFlags() - { - return HONOR_LEASE; - } - - public boolean willClientHonorLease() - { - return HONOR_LEASE == (getFlags() & HONOR_LEASE); - } - - public boolean doesClientRequestStrictInterpretation() - { - return STRICT_INTERPRETATION == (getFlags() & STRICT_INTERPRETATION); - } -} diff --git a/src/main/java/io/reactivesocket/DefaultReactiveSocket.java b/src/main/java/io/reactivesocket/DefaultReactiveSocket.java deleted file mode 100644 index be2d4dfe6..000000000 --- a/src/main/java/io/reactivesocket/DefaultReactiveSocket.java +++ /dev/null @@ -1,479 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.internal.Requester; -import io.reactivesocket.internal.Responder; -import io.reactivesocket.internal.rx.CompositeCompletable; -import io.reactivesocket.internal.rx.CompositeDisposable; -import io.reactivesocket.rx.Completable; -import io.reactivesocket.rx.Disposable; -import io.reactivesocket.rx.Observable; -import io.reactivesocket.rx.Observer; -import org.agrona.BitUtil; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import java.io.IOException; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Consumer; - -import static io.reactivesocket.LeaseGovernor.NULL_LEASE_GOVERNOR; - -/** - * An implementation of {@link ReactiveSocket} - */ -public class DefaultReactiveSocket implements ReactiveSocket { - private static final RequestHandler EMPTY_HANDLER = new RequestHandler.Builder().build(); - - private static final Consumer DEFAULT_ERROR_STREAM = t -> { - // TODO should we use SLF4j, use System.err, or swallow by default? - System.err.println("ReactiveSocket ERROR => " + t.getMessage() - + " [Provide errorStream handler to replace this default]"); - }; - - private final DuplexConnection connection; - private final boolean isServer; - private final Consumer errorStream; - private Requester requester; - private Responder responder; - private final ConnectionSetupPayload requestorSetupPayload; - private final RequestHandler clientRequestHandler; - private final ConnectionSetupHandler responderConnectionHandler; - private final LeaseGovernor leaseGovernor; - - private DefaultReactiveSocket( - DuplexConnection connection, - boolean isServer, - ConnectionSetupPayload serverRequestorSetupPayload, - RequestHandler clientRequestHandler, - ConnectionSetupHandler responderConnectionHandler, - LeaseGovernor leaseGovernor, - Consumer errorStream - ) { - this.connection = connection; - this.isServer = isServer; - this.requestorSetupPayload = serverRequestorSetupPayload; - this.clientRequestHandler = clientRequestHandler; - this.responderConnectionHandler = responderConnectionHandler; - this.leaseGovernor = leaseGovernor; - this.errorStream = errorStream; - } - - /** - * Create a ReactiveSocket from a client-side {@link DuplexConnection}. - *

- * A client-side connection is one that initiated the connection with a - * server and will define the ReactiveSocket behaviors via the - * {@link ConnectionSetupPayload} that define mime-types, leasing - * behavior and other connection-level details. - * - * @param connection - * DuplexConnection of client-side initiated connection for - * the ReactiveSocket protocol to use. - * @param setup - * ConnectionSetupPayload that defines mime-types and other - * connection behavior details. - * @param handler - * (Optional) RequestHandler for responding to requests from - * the server. If 'null' requests will be responded to with - * "Not Found" errors. - * @param errorStream - * (Optional) Callback for errors while processing streams - * over connection. If 'null' then error messages will be - * output to System.err. - * @return ReactiveSocket for start, shutdown and sending requests. - */ - public static ReactiveSocket fromClientConnection( - DuplexConnection connection, - ConnectionSetupPayload setup, - RequestHandler handler, - Consumer errorStream - ) { - if (connection == null) { - throw new IllegalArgumentException("DuplexConnection can not be null"); - } - if (setup == null) { - throw new IllegalArgumentException("ConnectionSetupPayload can not be null"); - } - final RequestHandler h = handler != null ? handler : EMPTY_HANDLER; - Consumer es = errorStream != null ? errorStream : DEFAULT_ERROR_STREAM; - return new DefaultReactiveSocket(connection, false, setup, h, null, NULL_LEASE_GOVERNOR, es); - } - - /** - * Create a ReactiveSocket from a client-side {@link DuplexConnection}. - *

- * A client-side connection is one that initiated the connection with a - * server and will define the ReactiveSocket behaviors via the - * {@link ConnectionSetupPayload} that define mime-types, leasing - * behavior and other connection-level details. - *

- * If this ReactiveSocket receives requests from the server it will respond - * with "Not Found" errors. - * - * @param connection - * DuplexConnection of client-side initiated connection for the - * ReactiveSocket protocol to use. - * @param setup - * ConnectionSetupPayload that defines mime-types and other - * connection behavior details. - * @param errorStream - * (Optional) Callback for errors while processing streams over - * connection. If 'null' then error messages will be output to - * System.err. - * @return ReactiveSocket for start, shutdown and sending requests. - */ - public static ReactiveSocket fromClientConnection( - DuplexConnection connection, - ConnectionSetupPayload setup, - Consumer errorStream - ) { - return fromClientConnection(connection, setup, EMPTY_HANDLER, errorStream); - } - - public static ReactiveSocket fromClientConnection( - DuplexConnection connection, - ConnectionSetupPayload setup - ) { - return fromClientConnection(connection, setup, EMPTY_HANDLER, DEFAULT_ERROR_STREAM); - } - - /** - * Create a ReactiveSocket from a server-side {@link DuplexConnection}. - *

- * A server-side connection is one that accepted the connection from a - * client and will define the ReactiveSocket behaviors via the - * {@link ConnectionSetupPayload} that define mime-types, leasing behavior - * and other connection-level details. - * - * @param connection - * @param connectionHandler - * @param errorConsumer - * @return - */ - public static ReactiveSocket fromServerConnection( - DuplexConnection connection, - ConnectionSetupHandler connectionHandler, - LeaseGovernor leaseGovernor, - Consumer errorConsumer - ) { - return new DefaultReactiveSocket(connection, true, null, null, connectionHandler, - leaseGovernor, errorConsumer); - } - - public static ReactiveSocket fromServerConnection( - DuplexConnection connection, - ConnectionSetupHandler connectionHandler - ) { - return fromServerConnection(connection, connectionHandler, NULL_LEASE_GOVERNOR, t -> {}); - } - - /** - * Initiate a request response exchange - */ - @Override - public Publisher requestResponse(final Payload payload) { - assertRequester(); - return requester.requestResponse(payload); - } - - @Override - public Publisher fireAndForget(final Payload payload) { - assertRequester(); - return requester.fireAndForget(payload); - } - - @Override - public Publisher requestStream(final Payload payload) { - assertRequester(); - return requester.requestStream(payload); - } - - @Override - public Publisher requestSubscription(final Payload payload) { - assertRequester(); - return requester.requestSubscription(payload); - } - - @Override - public Publisher requestChannel(final Publisher payloads) { - assertRequester(); - return requester.requestChannel(payloads); - } - - @Override - public Publisher metadataPush(final Payload payload) { - assertRequester(); - return requester.metadataPush(payload); - } - - private void assertRequester() { - if (requester == null) { - if (isServer) { - if (responder == null) { - throw new IllegalStateException("Connection not initialized. " + - "Please 'start()' before submitting requests"); - } else { - throw new IllegalStateException("Setup not yet received from client. " + - "Please wait until Setup is completed, then retry."); - } - } else { - throw new IllegalStateException("Connection not initialized. " + - "Please 'start()' before submitting requests"); - } - } - } - - @Override - public double availability() { - // TODO: can happen in either direction - assertRequester(); - return requester.availability(); - } - - @Override - public void sendLease(int ttl, int numberOfRequests) { - // TODO: can happen in either direction - responder.sendLease(ttl, numberOfRequests); - } - - @Override - public final void start(Completable c) { - if (isServer) { - responder = Responder.createServerResponder( - new ConnectionFilter(connection, ConnectionFilter.STREAMS.FROM_CLIENT_EVEN), - responderConnectionHandler, - leaseGovernor, - errorStream, - c, - setupPayload -> { - Completable two = new Completable() { - // wait for 2 success, or 1 error to pass on - AtomicInteger count = new AtomicInteger(); - - @Override - public void success() { - if (count.incrementAndGet() == 2) { - requesterReady.success(); - } - } - - @Override - public void error(Throwable e) { - requesterReady.error(e); - } - }; - requester = Requester.createServerRequester( - new ConnectionFilter(connection, ConnectionFilter.STREAMS.FROM_SERVER_ODD), - setupPayload, - errorStream, - two - ); - two.success(); // now that the reference is assigned in case of synchronous setup - }, - this); - } else { - Completable both = new Completable() { - // wait for 2 success, or 1 error to pass on - AtomicInteger count = new AtomicInteger(); - - @Override - public void success() { - if (count.incrementAndGet() == 2) { - c.success(); - } - } - - @Override - public void error(Throwable e) { - c.error(e); - } - }; - requester = Requester.createClientRequester( - new ConnectionFilter(connection, ConnectionFilter.STREAMS.FROM_CLIENT_EVEN), - requestorSetupPayload, - errorStream, - new Completable() { - @Override - public void success() { - requesterReady.success(); - both.success(); - } - - @Override - public void error(Throwable e) { - requesterReady.error(e); - both.error(e); - } - }); - responder = Responder.createClientResponder( - new ConnectionFilter(connection, ConnectionFilter.STREAMS.FROM_SERVER_ODD), - clientRequestHandler, - leaseGovernor, - errorStream, - both, - this - ); - } - } - - private final CompositeCompletable requesterReady = new CompositeCompletable(); - - @Override - public final void onRequestReady(Completable c) { - requesterReady.add(c); - } - - @Override - public final void onRequestReady(Consumer c) { - requesterReady.add(new Completable() { - @Override - public void success() { - c.accept(null); - } - - @Override - public void error(Throwable e) { - c.accept(e); - } - }); - } - - private static class ConnectionFilter implements DuplexConnection { - private enum STREAMS { - FROM_CLIENT_EVEN, FROM_SERVER_ODD; - } - - private final DuplexConnection connection; - private final STREAMS s; - - private ConnectionFilter(DuplexConnection connection, STREAMS s) { - this.connection = connection; - this.s = s; - } - - @Override - public void close() throws IOException { - connection.close(); // forward - } - - @Override - public Observable getInput() { - return new Observable() { - @Override - public void subscribe(Observer o) { - CompositeDisposable cd = new CompositeDisposable(); - o.onSubscribe(cd); - connection.getInput().subscribe(new Observer() { - - @Override - public void onNext(Frame t) { - int streamId = t.getStreamId(); - FrameType type = t.getType(); - if (streamId == 0) { - if (FrameType.SETUP.equals(type) && s == STREAMS.FROM_CLIENT_EVEN) { - o.onNext(t); - } else if (FrameType.LEASE.equals(type)) { - o.onNext(t); - } else if (FrameType.ERROR.equals(type)) { - // o.onNext(t); // TODO this doesn't work - } else if (FrameType.KEEPALIVE.equals(type)) { - o.onNext(t); // TODO need tests - } else if (FrameType.METADATA_PUSH.equals(type)) { - o.onNext(t); - } - } else if (BitUtil.isEven(streamId)) { - if (s == STREAMS.FROM_CLIENT_EVEN) { - o.onNext(t); - } - } else { - if (s == STREAMS.FROM_SERVER_ODD) { - o.onNext(t); - } - } - } - - @Override - public void onError(Throwable e) { - o.onError(e); - } - - @Override - public void onComplete() { - o.onComplete(); - } - - @Override - public void onSubscribe(Disposable d) { - cd.add(d); - } - }); - } - }; - } - - @Override - public void addOutput(Publisher o, Completable callback) { - connection.addOutput(o, callback); - } - - @Override - public void addOutput(Frame f, Completable callback) { - connection.addOutput(f, callback); - } - - }; - - @Override - public void close() throws Exception { - connection.close(); - leaseGovernor.unregister(responder); - if (requester != null) { - requester.shutdown(); - } - if (responder != null) { - responder.shutdown(); - } - } - - @Override - public void shutdown() { - try { - close(); - } catch (Exception e) { - throw new RuntimeException("Failed Shutdown", e); - } - } - - private static Publisher error(Throwable e) { - return (Subscriber s) -> { - s.onSubscribe(new Subscription() { - @Override - public void request(long n) { - // should probably worry about n==0 - s.onError(e); - } - - @Override - public void cancel() { - // ignoring just because - } - }); - }; - } -} diff --git a/src/main/java/io/reactivesocket/DuplexConnection.java b/src/main/java/io/reactivesocket/DuplexConnection.java deleted file mode 100644 index 772205404..000000000 --- a/src/main/java/io/reactivesocket/DuplexConnection.java +++ /dev/null @@ -1,39 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.rx.Completable; -import io.reactivesocket.rx.Observable; -import org.reactivestreams.Publisher; - -import java.io.Closeable; - -/** - * Represents a connection with input/output that the protocol uses. - */ -public interface DuplexConnection extends Closeable { - - Observable getInput(); - - void addOutput(Publisher o, Completable callback); - - default void addOutput(Frame frame, Completable callback) { - addOutput(s -> { - s.onNext(frame); - s.onComplete(); - }, callback); - } -} diff --git a/src/main/java/io/reactivesocket/Frame.java b/src/main/java/io/reactivesocket/Frame.java deleted file mode 100644 index a71e90eb9..000000000 --- a/src/main/java/io/reactivesocket/Frame.java +++ /dev/null @@ -1,610 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.internal.*; -import io.reactivesocket.internal.frame.ErrorFrameFlyweight; -import io.reactivesocket.internal.frame.FrameHeaderFlyweight; -import io.reactivesocket.internal.frame.FramePool; -import io.reactivesocket.internal.frame.LeaseFrameFlyweight; -import io.reactivesocket.internal.frame.RequestFrameFlyweight; -import io.reactivesocket.internal.frame.RequestNFrameFlyweight; -import io.reactivesocket.internal.frame.SetupFrameFlyweight; -import io.reactivesocket.internal.frame.UnpooledFrame; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; -import java.nio.charset.Charset; - -import static java.lang.System.getProperty; - -/** - * Represents a Frame sent over a {@link DuplexConnection}. - *

- * This provides encoding, decoding and field accessors. - */ -public class Frame implements Payload -{ - public static final ByteBuffer NULL_BYTEBUFFER = FrameHeaderFlyweight.NULL_BYTEBUFFER; - public static final int DATA_MTU = 32 * 1024; - public static final int METADATA_MTU = 32 * 1024; - - /* - * ThreadLocal handling in the pool itself. We don't have a per thread pool at this level. - */ - - private static final String FRAME_POOLER_CLASS_NAME = - getProperty("io.reactivesocket.FramePool", "io.reactivesocket.internal.UnpooledFrame"); - private static final FramePool POOL; - - static - { - FramePool tmpPool; - - try - { - tmpPool = (FramePool)Class.forName(FRAME_POOLER_CLASS_NAME).newInstance(); - } - catch (final Exception ex) - { - tmpPool = new UnpooledFrame(); - } - - POOL = tmpPool; - } - - // not final so we can reuse this object - private MutableDirectBuffer directBuffer; - private int offset = 0; - private int length = 0; - - private Frame(final MutableDirectBuffer directBuffer) - { - this.directBuffer = directBuffer; - } - - /** - * Return underlying {@link ByteBuffer} for frame - * - * @return underlying {@link ByteBuffer} for frame - */ - public ByteBuffer getByteBuffer() { - return directBuffer.byteBuffer(); - } - - /** - * Return {@link ByteBuffer} that is a {@link ByteBuffer#slice()} for the frame data - * - * If no data is present, the ByteBuffer will have 0 capacity. - * - * @return ByteBuffer containing the data - */ - public ByteBuffer getData() - { - return FrameHeaderFlyweight.sliceFrameData(directBuffer, offset, 0); - } - - /** - * Return {@link ByteBuffer} that is a {@link ByteBuffer#slice()} for the frame metadata - * - * If no metadata is present, the ByteBuffer will have 0 capacity. - * - * @return ByteBuffer containing the data - */ - public ByteBuffer getMetadata() - { - return FrameHeaderFlyweight.sliceFrameMetadata(directBuffer, offset, 0); - } - - /** - * Return frame stream identifier - * - * @return frame stream identifier - */ - public int getStreamId() { - return FrameHeaderFlyweight.streamId(directBuffer, offset); - } - - /** - * Return frame {@link FrameType} - * - * @return frame type - */ - public FrameType getType() { - return FrameHeaderFlyweight.frameType(directBuffer, offset); - } - - /** - * Return the offset in the buffer of the frame - * - * @return offset of frame within the buffer - */ - public int offset() - { - return offset; - } - - /** - * Return the encoded length of a frame or the frame length - * - * @return frame length - */ - public int length() - { - return length; - } - - /** - * Return the flags field for the frame - * - * @return frame flags field value - */ - public int flags() - { - return FrameHeaderFlyweight.flags(directBuffer, offset); - } - - /** - * Mutates this Frame to contain the given ByteBuffer - * - * @param byteBuffer to wrap - */ - public void wrap(final ByteBuffer byteBuffer, final int offset) - { - wrap(POOL.acquireMutableDirectBuffer(byteBuffer), offset); - } - - /** - * Mutates this Frame to contain the given MutableDirectBuffer - * - * @param directBuffer to wrap - */ - public void wrap(final MutableDirectBuffer directBuffer, final int offset) - { - this.directBuffer = directBuffer; - this.offset = offset; - } - - /** - * Acquire a free Frame backed by given ByteBuffer - * - * @param byteBuffer to wrap - * @return new {@link Frame} - */ - public static Frame from(final ByteBuffer byteBuffer) { - return POOL.acquireFrame(byteBuffer); - } - - /** - * Acquire a free Frame and back with the given {@link DirectBuffer} starting at offset for length bytes - * - * @param directBuffer to use as backing buffer - * @param offset of start of frame - * @param length of frame in bytes - * @return frame - */ - public static Frame from(final DirectBuffer directBuffer, final int offset, final int length) - { - final Frame frame = POOL.acquireFrame((MutableDirectBuffer)directBuffer); - frame.offset = offset; - frame.length = length; - - return frame; - } - - /** - * Construct a new Frame from the given {@link MutableDirectBuffer} - * - * NOTE: always allocates. Used for pooling. - * - * @param directBuffer to wrap - * @return new {@link Frame} - */ - public static Frame allocate(final MutableDirectBuffer directBuffer) - { - return new Frame(directBuffer); - } - - /** - * Release frame for re-use. - */ - public void release() - { - POOL.release(this.directBuffer); - POOL.release(this); - } - - /** - * Mutates this Frame to contain the given parameters. - * - * NOTE: acquires a new backing buffer and releases current backing buffer - * - * @param streamId to include in frame - * @param type to include in frame - * @param data to include in frame - */ - public void wrap(final int streamId, final FrameType type, final ByteBuffer data) - { - POOL.release(this.directBuffer); - - this.directBuffer = - POOL.acquireMutableDirectBuffer(FrameHeaderFlyweight.computeFrameHeaderLength(type, 0, data.remaining())); - - this.length = FrameHeaderFlyweight.encode(this.directBuffer, offset, streamId, 0, type, NULL_BYTEBUFFER, data); - } - - /* TODO: - * - * fromRequest(type, id, payload) - * fromKeepalive(ByteBuffer data) - * - */ - - // SETUP specific getters - public static class Setup - { - - private Setup() {} - - public static Frame from( - int flags, - int keepaliveInterval, - int maxLifetime, - String metadataMimeType, - String dataMimeType, - Payload payload) - { - final ByteBuffer metadata = payload.getMetadata(); - final ByteBuffer data = payload.getData(); - - final Frame frame = - POOL.acquireFrame(SetupFrameFlyweight.computeFrameLength(metadataMimeType, dataMimeType, metadata.remaining(), data.remaining())); - - frame.length = SetupFrameFlyweight.encode( - frame.directBuffer, frame.offset, flags, keepaliveInterval, maxLifetime, metadataMimeType, dataMimeType, metadata, data); - return frame; - } - - public static int getFlags(final Frame frame) - { - ensureFrameType(FrameType.SETUP, frame); - final int flags = FrameHeaderFlyweight.flags(frame.directBuffer, frame.offset); - - return flags & (SetupFrameFlyweight.FLAGS_WILL_HONOR_LEASE | SetupFrameFlyweight.FLAGS_STRICT_INTERPRETATION); - } - - public static int version(final Frame frame) - { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.version(frame.directBuffer, frame.offset); - } - - public static int keepaliveInterval(final Frame frame) - { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.keepaliveInterval(frame.directBuffer, frame.offset); - } - - public static int maxLifetime(final Frame frame) - { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.maxLifetime(frame.directBuffer, frame.offset); - } - - public static String metadataMimeType(final Frame frame) - { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.metadataMimeType(frame.directBuffer, frame.offset); - } - - public static String dataMimeType(final Frame frame) - { - ensureFrameType(FrameType.SETUP, frame); - return SetupFrameFlyweight.dataMimeType(frame.directBuffer, frame.offset); - } - } - - public static class Error - { - - private Error() {} - - public static Frame from( - int streamId, - final Throwable throwable, - ByteBuffer metadata, - ByteBuffer data - ) { - final int code = ErrorFrameFlyweight.errorCodeFromException(throwable); - final Frame frame = POOL.acquireFrame( - ErrorFrameFlyweight.computeFrameLength(metadata.remaining(), data.remaining())); - - frame.length = ErrorFrameFlyweight.encode( - frame.directBuffer, frame.offset, streamId, code, metadata, data); - return frame; - } - - public static Frame from( - int streamId, - final Throwable throwable, - ByteBuffer metadata - ) { - String data = throwable.getMessage() == null ? "" : throwable.getMessage(); - byte[] bytes = data.getBytes(Charset.forName("UTF-8")); - final ByteBuffer dataBuffer = ByteBuffer.wrap(bytes); - - return from(streamId, throwable, metadata, dataBuffer); - } - - public static Frame from( - int streamId, - final Throwable throwable - ) { - return from(streamId, throwable, NULL_BYTEBUFFER); - } - - public static int errorCode(final Frame frame) - { - ensureFrameType(FrameType.ERROR, frame); - return ErrorFrameFlyweight.errorCode(frame.directBuffer, frame.offset); - } - } - - public static class Lease - { - private Lease() {} - - public static Frame from(int ttl, int numberOfRequests, ByteBuffer metadata) - { - final Frame frame = POOL.acquireFrame(LeaseFrameFlyweight.computeFrameLength(metadata.remaining())); - - frame.length = LeaseFrameFlyweight.encode(frame.directBuffer, frame.offset, ttl, numberOfRequests, metadata); - return frame; - } - - public static int ttl(final Frame frame) - { - ensureFrameType(FrameType.LEASE, frame); - return LeaseFrameFlyweight.ttl(frame.directBuffer, frame.offset); - } - - public static int numberOfRequests(final Frame frame) - { - ensureFrameType(FrameType.LEASE, frame); - return LeaseFrameFlyweight.numRequests(frame.directBuffer, frame.offset); - } - } - - public static class RequestN - { - private RequestN() {} - - public static Frame from(int streamId, int requestN) - { - final Frame frame = POOL.acquireFrame(RequestNFrameFlyweight.computeFrameLength()); - - frame.length = RequestNFrameFlyweight.encode(frame.directBuffer, frame.offset, streamId, requestN); - return frame; - } - - public static long requestN(final Frame frame) - { - ensureFrameType(FrameType.REQUEST_N, frame); - return RequestNFrameFlyweight.requestN(frame.directBuffer, frame.offset); - } - } - - public static class Request - { - private Request() {} - - public static Frame from(int streamId, FrameType type, Payload payload, int initialRequestN) - { - final ByteBuffer d = payload.getData() != null ? payload.getData() : NULL_BYTEBUFFER; - final ByteBuffer md = payload.getMetadata() != null ? payload.getMetadata() : NULL_BYTEBUFFER; - - final Frame frame = POOL.acquireFrame(RequestFrameFlyweight.computeFrameLength(type, md.remaining(), d.remaining())); - - if (type.hasInitialRequestN()) - { - frame.length = RequestFrameFlyweight.encode(frame.directBuffer, frame.offset, streamId, 0, type, initialRequestN, md, d); - } - else - { - frame.length = RequestFrameFlyweight.encode(frame.directBuffer, frame.offset, streamId, 0, type, md, d); - } - - return frame; - } - - public static Frame from(int streamId, FrameType type, int flags) - { - final Frame frame = POOL.acquireFrame(RequestFrameFlyweight.computeFrameLength(type, 0, 0)); - - frame.length = RequestFrameFlyweight.encode(frame.directBuffer, frame.offset, streamId, flags, type, NULL_BYTEBUFFER, NULL_BYTEBUFFER); - return frame; - } - - public static Frame from(int streamId, FrameType type, ByteBuffer metadata, ByteBuffer data, int initialRequestN, int flags) - { - final Frame frame = POOL.acquireFrame(RequestFrameFlyweight.computeFrameLength(type, metadata.remaining(), data.remaining())); - - frame.length = RequestFrameFlyweight.encode(frame.directBuffer, frame.offset, streamId, flags, type, initialRequestN, metadata, data); - return frame; - - } - - public static long initialRequestN(final Frame frame) - { - final FrameType type = frame.getType(); - long result; - - if (!type.isRequestType()) - { - throw new AssertionError("expected request type, but saw " + type.name()); - } - - switch (frame.getType()) - { - case REQUEST_RESPONSE: - result = 1; - break; - case FIRE_AND_FORGET: - result = 0; - break; - default: - result = RequestFrameFlyweight.initialRequestN(frame.directBuffer, frame.offset); - break; - } - - return result; - } - - public static boolean isRequestChannelComplete(final Frame frame) - { - ensureFrameType(FrameType.REQUEST_CHANNEL, frame); - final int flags = FrameHeaderFlyweight.flags(frame.directBuffer, frame.offset); - - return (flags & RequestFrameFlyweight.FLAGS_REQUEST_CHANNEL_C) == RequestFrameFlyweight.FLAGS_REQUEST_CHANNEL_C; - } - } - - public static class Response - { - - private Response() {} - - public static Frame from(int streamId, FrameType type, Payload payload) - { - final ByteBuffer data = payload.getData() != null ? payload.getData() : NULL_BYTEBUFFER; - final ByteBuffer metadata = payload.getMetadata() != null ? payload.getMetadata() : NULL_BYTEBUFFER; - - final Frame frame = - POOL.acquireFrame(FrameHeaderFlyweight.computeFrameHeaderLength(type, metadata.remaining(), data.remaining())); - - frame.length = FrameHeaderFlyweight.encode(frame.directBuffer, frame.offset, streamId, 0, type, metadata, data); - return frame; - } - - public static Frame from(int streamId, FrameType type, ByteBuffer metadata, ByteBuffer data, int flags) - { - final Frame frame = - POOL.acquireFrame(FrameHeaderFlyweight.computeFrameHeaderLength(type, metadata.remaining(), data.remaining())); - - frame.length = FrameHeaderFlyweight.encode(frame.directBuffer, frame.offset, streamId, flags, type, metadata, data); - return frame; - } - - public static Frame from(int streamId, FrameType type) - { - final Frame frame = - POOL.acquireFrame(FrameHeaderFlyweight.computeFrameHeaderLength(type, 0, 0)); - - frame.length = FrameHeaderFlyweight.encode( - frame.directBuffer, frame.offset, streamId, 0, type, Frame.NULL_BYTEBUFFER, Frame.NULL_BYTEBUFFER); - return frame; - } - } - - public static class Cancel - { - - private Cancel() {} - - public static Frame from(int streamId) - { - final Frame frame = - POOL.acquireFrame(FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.CANCEL, 0, 0)); - - frame.length = FrameHeaderFlyweight.encode( - frame.directBuffer, frame.offset, streamId, 0, FrameType.CANCEL, Frame.NULL_BYTEBUFFER, Frame.NULL_BYTEBUFFER); - return frame; - } - } - - public static class Keepalive - { - - private Keepalive() {} - - public static Frame from(ByteBuffer data, boolean respond) - { - final Frame frame = - POOL.acquireFrame(FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.KEEPALIVE, 0, data.remaining())); - - final int flags = respond ? FrameHeaderFlyweight.FLAGS_KEEPALIVE_R : 0; - - frame.length = FrameHeaderFlyweight.encode( - frame.directBuffer, frame.offset, 0, flags, FrameType.KEEPALIVE, Frame.NULL_BYTEBUFFER, data); - - return frame; - } - - public static boolean hasRespondFlag(final Frame frame) - { - ensureFrameType(FrameType.KEEPALIVE, frame); - final int flags = FrameHeaderFlyweight.flags(frame.directBuffer, frame.offset); - - return (flags & FrameHeaderFlyweight.FLAGS_KEEPALIVE_R) == FrameHeaderFlyweight.FLAGS_KEEPALIVE_R; - } - } - - public static void ensureFrameType(final FrameType frameType, final Frame frame) - { - final FrameType typeInFrame = frame.getType(); - - if (typeInFrame != frameType) - { - throw new AssertionError("expected " + frameType + ", but saw" + typeInFrame); - } - } - - @Override - public String toString() { - FrameType type = FrameType.UNDEFINED; - StringBuilder payload = new StringBuilder(); - long streamId = -1; - - try - { - type = FrameHeaderFlyweight.frameType(directBuffer, 0); - ByteBuffer byteBuffer; - byte[] bytes; - - byteBuffer = FrameHeaderFlyweight.sliceFrameMetadata(directBuffer, 0, 0); - if (0 < byteBuffer.capacity()) - { - bytes = new byte[byteBuffer.capacity()]; - byteBuffer.get(bytes); - payload.append(String.format("metadata: \"%s\" ", new String(bytes, Charset.forName("UTF-8")))); - } - - byteBuffer = FrameHeaderFlyweight.sliceFrameData(directBuffer, 0, 0); - if (0 < byteBuffer.capacity()) - { - bytes = new byte[byteBuffer.capacity()]; - byteBuffer.get(bytes); - payload.append(String.format("data: \"%s\"", new String(bytes, Charset.forName("UTF-8")))); - } - - streamId = FrameHeaderFlyweight.streamId(directBuffer, 0); - } catch (Exception e) { - e.printStackTrace(); - } - return "Frame[" + offset + "] => Stream ID: " + streamId + " Type: " + type + " Payload: " + payload; - } -} diff --git a/src/main/java/io/reactivesocket/FrameType.java b/src/main/java/io/reactivesocket/FrameType.java deleted file mode 100644 index 242cf7af8..000000000 --- a/src/main/java/io/reactivesocket/FrameType.java +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -/** - * Types of {@link Frame} that can be sent. - */ -public enum FrameType -{ - // blank type that is not defined - UNDEFINED(0x00), - // Connection - SETUP(0x01, Flags.CAN_HAVE_METADATA_AND_DATA), - LEASE(0x02, Flags.CAN_HAVE_METADATA), - KEEPALIVE(0x03, Flags.CAN_HAVE_DATA), - // Requester to start request - REQUEST_RESPONSE(0x04, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE), - FIRE_AND_FORGET(0x05, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE), - REQUEST_STREAM(0x06, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE | Flags.HAS_INITIAL_REQUEST_N), - REQUEST_SUBSCRIPTION(0x07, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE | Flags.HAS_INITIAL_REQUEST_N), - REQUEST_CHANNEL(0x08, Flags.CAN_HAVE_METADATA_AND_DATA | Flags.IS_REQUEST_TYPE | Flags.HAS_INITIAL_REQUEST_N), - // Requester mid-stream - REQUEST_N(0x09), - CANCEL(0x0A, Flags.CAN_HAVE_METADATA), - // Responder - RESPONSE(0x0B, Flags.CAN_HAVE_METADATA_AND_DATA), - ERROR(0x0C, Flags.CAN_HAVE_METADATA_AND_DATA), - // Requester & Responder - METADATA_PUSH(0x0D, Flags.CAN_HAVE_METADATA), - // synthetic types from Responder for use by the rest of the machinery - NEXT(0x0E, Flags.CAN_HAVE_METADATA_AND_DATA), - COMPLETE(0x0F), - NEXT_COMPLETE(0x10, Flags.CAN_HAVE_METADATA_AND_DATA); - - private static class Flags - { - private Flags() {} - - private static final int CAN_HAVE_DATA = 0b0001; - private static final int CAN_HAVE_METADATA = 0b0010; - private static final int CAN_HAVE_METADATA_AND_DATA = 0b0011; - private static final int IS_REQUEST_TYPE = 0b0100; - private static final int HAS_INITIAL_REQUEST_N = 0b1000; - } - - private static FrameType[] typesById; - - private final int id; - private final int flags; - - /** - * Index types by id for indexed lookup. - */ - static { - int max = 0; - - for (FrameType t : values()) { - max = Math.max(t.id, max); - } - - typesById = new FrameType[max + 1]; - - for (FrameType t : values()) { - typesById[t.id] = t; - } - } - - FrameType(final int id) - { - this(id, 0); - } - - FrameType(int id, int flags) { - this.id = id; - this.flags = flags; - } - - public int getEncodedType() { - return id; - } - - public boolean isRequestType() - { - return Flags.IS_REQUEST_TYPE == (flags & Flags.IS_REQUEST_TYPE); - } - - public boolean hasInitialRequestN() - { - return Flags.HAS_INITIAL_REQUEST_N == (flags & Flags.HAS_INITIAL_REQUEST_N); - } - - public boolean canHaveData() - { - return Flags.CAN_HAVE_DATA == (flags & Flags.CAN_HAVE_DATA); - } - - public boolean canHaveMetadata() - { - return Flags.CAN_HAVE_METADATA == (flags & Flags.CAN_HAVE_METADATA); - } - - // TODO: offset of metadata and data (simplify parsing) naming: endOfFrameHeaderOffset() - public int payloadOffset() - { - return 0; - } - - public static FrameType from(int id) { - return typesById[id]; - } -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/LeaseGovernor.java b/src/main/java/io/reactivesocket/LeaseGovernor.java deleted file mode 100644 index 854958adc..000000000 --- a/src/main/java/io/reactivesocket/LeaseGovernor.java +++ /dev/null @@ -1,37 +0,0 @@ -package io.reactivesocket; - -import io.reactivesocket.internal.Responder; -import io.reactivesocket.lease.NullLeaseGovernor; -import io.reactivesocket.lease.UnlimitedLeaseGovernor; - -public interface LeaseGovernor { - public static final LeaseGovernor NULL_LEASE_GOVERNOR = new NullLeaseGovernor(); - public static final LeaseGovernor UNLIMITED_LEASE_GOVERNOR = new UnlimitedLeaseGovernor(); - - /** - * Register a responder into the LeaseGovernor. - * This give the responsibility to the leaseGovernor to send lease to the responder. - * - * @param responder the responder that will receive lease - */ - public void register(Responder responder); - - /** - * Unregister a responder from the LeaseGovernor. - * Depending on the implementation, this action may trigger a rebalancing of - * the tickets/window to the remaining responders. - * @param responder the responder to be removed - */ - public void unregister(Responder responder); - - /** - * Check if the message received by the responder is valid (i.e. received during a - * valid lease window) - * This action my have side effect in the LeaseGovernor. - * - * @param responder receiving the message - * @param frame the received frame - * @return - */ - public boolean accept(Responder responder, Frame frame); -} diff --git a/src/main/java/io/reactivesocket/ReactiveSocket.java b/src/main/java/io/reactivesocket/ReactiveSocket.java deleted file mode 100644 index 8f8c2a10c..000000000 --- a/src/main/java/io/reactivesocket/ReactiveSocket.java +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.rx.Completable; -import org.reactivestreams.Publisher; - -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; - -/** - * Interface for a connection that supports sending requests and receiving responses - */ -public interface ReactiveSocket extends AutoCloseable { - Publisher requestResponse(final Payload payload); - - Publisher fireAndForget(final Payload payload); - - Publisher requestStream(final Payload payload); - - Publisher requestSubscription(final Payload payload); - - Publisher requestChannel(final Publisher payloads); - - Publisher metadataPush(final Payload payload); - - /** - * Client check for availability to send request based on lease - * - * @return 0.0 to 1.0 indicating availability of sending requests - */ - double availability(); - - /** - * Start protocol processing on the given DuplexConnection. - */ - void start(Completable c); - - /** - * Start and block the current thread until startup is finished. - * - * @throws RuntimeException - * of InterruptedException - */ - default void startAndWait() { - CountDownLatch latch = new CountDownLatch(1); - AtomicReference err = new AtomicReference<>(); - start(new Completable() { - @Override - public void success() { - latch.countDown(); - } - - @Override - public void error(Throwable e) { - latch.countDown(); - } - }); - try { - latch.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - if (err.get() != null) { - throw new RuntimeException(err.get()); - } - } - - /** - * Invoked when Requester is ready. Non-null exception if error. Null if success. - * - * @param c - */ - void onRequestReady(Consumer c); - - /** - * Invoked when Requester is ready with success or fail. - * - * @param c - */ - void onRequestReady(Completable c); - - /** - * Server granting new lease information to client - * - * Initial lease semantics are that server waits for periodic granting of leases by server side. - * - * @param ttl - * @param numberOfRequests - */ - void sendLease(int ttl, int numberOfRequests); - - void shutdown(); -} diff --git a/src/main/java/io/reactivesocket/ReactiveSocketFactory.java b/src/main/java/io/reactivesocket/ReactiveSocketFactory.java deleted file mode 100644 index 11d67f65c..000000000 --- a/src/main/java/io/reactivesocket/ReactiveSocketFactory.java +++ /dev/null @@ -1,116 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.internal.rx.EmptySubscription; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import java.util.NoSuchElementException; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.concurrent.atomic.AtomicBoolean; - -@FunctionalInterface -public interface ReactiveSocketFactory { - - Publisher call(T t); - - /** - * Gets a socket in a blocking manner - * @param t configuration to create the reactive socket - * @return blocks on create the socket - */ - default R callAndWait(T t) { - CompletableFuture future = new CompletableFuture<>(); - - call(t) - .subscribe(new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - s.request(1); - } - - @Override - public void onNext(R reactiveSocket) { - future.complete(reactiveSocket); - } - - @Override - public void onError(Throwable t) { - future.completeExceptionally(t); - } - - @Override - public void onComplete() { - future.completeExceptionally(new NoSuchElementException("Sequence contains no elements")); - } - }); - - return future.join(); - } - - /** - * - * @param t the configuration used to create the reactive socket - * @param timeout timeout - * @param timeUnit timeout units - * @param executorService ScheduledExecutorService to schedule the timeout on - * @return - */ - default Publisher call(T t, long timeout, TimeUnit timeUnit, ScheduledExecutorService executorService) { - Publisher reactiveSocketPublisher = subscriber -> { - AtomicBoolean complete = new AtomicBoolean(); - subscriber.onSubscribe(EmptySubscription.INSTANCE); - call(t) - .subscribe(new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - s.request(1); - } - - @Override - public void onNext(R reactiveSocket) { - subscriber.onNext(reactiveSocket); - } - - @Override - public void onError(Throwable t) { - subscriber.onError(t); - } - - @Override - public void onComplete() { - if (complete.compareAndSet(false, true)) { - subscriber.onComplete(); - } - } - }); - - executorService.schedule(() -> { - if (complete.compareAndSet(false, true)) { - subscriber.onError(new TimeoutException()); - } - }, timeout, timeUnit); - }; - - return reactiveSocketPublisher; - } - -} diff --git a/src/main/java/io/reactivesocket/RequestHandler.java b/src/main/java/io/reactivesocket/RequestHandler.java deleted file mode 100644 index d4f0a821e..000000000 --- a/src/main/java/io/reactivesocket/RequestHandler.java +++ /dev/null @@ -1,140 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.internal.PublisherUtils; -import org.reactivestreams.Publisher; - -import java.util.function.Function; - -public abstract class RequestHandler { - - public static final Function> NO_REQUEST_RESPONSE_HANDLER = - payload -> PublisherUtils.errorPayload(new RuntimeException("No 'requestResponse' handler")); - - public static final Function> NO_REQUEST_STREAM_HANDLER = - payload -> PublisherUtils.errorPayload(new RuntimeException("No 'requestStream' handler")); - - public static final Function> NO_REQUEST_SUBSCRIPTION_HANDLER = - payload -> PublisherUtils.errorPayload(new RuntimeException("No 'requestSubscription' handler")); - - public static final Function> NO_FIRE_AND_FORGET_HANDLER = - payload -> PublisherUtils.errorVoid(new RuntimeException("No 'fireAndForget' handler")); - - public static final Function, Publisher> NO_REQUEST_CHANNEL_HANDLER = - payloads -> PublisherUtils.errorPayload(new RuntimeException("No 'requestChannel' handler")); - - public static final Function> NO_METADATA_PUSH_HANDLER = - payload -> PublisherUtils.errorVoid(new RuntimeException("No 'metadataPush' handler")); - - public abstract Publisher handleRequestResponse(final Payload payload); - - public abstract Publisher handleRequestStream(final Payload payload); - - public abstract Publisher handleSubscription(final Payload payload); - - public abstract Publisher handleFireAndForget(final Payload payload); - - /** - * @note The initialPayload will also be part of the inputs publisher. - * It is there to simplify routing logic. - */ - public abstract Publisher handleChannel(Payload initialPayload, final Publisher inputs); - - public abstract Publisher handleMetadataPush(final Payload payload); - - public static class Builder - { - private Function> handleRequestResponse = NO_REQUEST_RESPONSE_HANDLER; - private Function> handleRequestStream = NO_REQUEST_STREAM_HANDLER; - private Function> handleRequestSubscription = NO_REQUEST_SUBSCRIPTION_HANDLER; - private Function> handleFireAndForget = NO_FIRE_AND_FORGET_HANDLER; - private Function, Publisher> handleRequestChannel = NO_REQUEST_CHANNEL_HANDLER; - private Function> handleMetadataPush = NO_METADATA_PUSH_HANDLER; - - public Builder withRequestResponse(final Function> handleRequestResponse) - { - this.handleRequestResponse = handleRequestResponse; - return this; - } - - public Builder withRequestStream(final Function> handleRequestStream) - { - this.handleRequestStream = handleRequestStream; - return this; - } - - public Builder withRequestSubscription(final Function> handleRequestSubscription) - { - this.handleRequestSubscription = handleRequestSubscription; - return this; - } - - public Builder withFireAndForget(final Function> handleFireAndForget) - { - this.handleFireAndForget = handleFireAndForget; - return this; - } - - public Builder withRequestChannel(final Function , Publisher> handleRequestChannel) - { - this.handleRequestChannel = handleRequestChannel; - return this; - } - - public Builder withMetadataPush(final Function> handleMetadataPush) - { - this.handleMetadataPush = handleMetadataPush; - return this; - } - - public RequestHandler build() - { - return new RequestHandler() - { - public Publisher handleRequestResponse(Payload payload) - { - return handleRequestResponse.apply(payload); - } - - public Publisher handleRequestStream(Payload payload) - { - return handleRequestStream.apply(payload); - } - - public Publisher handleSubscription(Payload payload) - { - return handleRequestSubscription.apply(payload); - } - - public Publisher handleFireAndForget(Payload payload) - { - return handleFireAndForget.apply(payload); - } - - public Publisher handleChannel(Payload initialPayload, Publisher inputs) - { - return handleRequestChannel.apply(inputs); - } - - public Publisher handleMetadataPush(Payload payload) - { - return handleMetadataPush.apply(payload); - } - }; - } - } -} diff --git a/src/main/java/io/reactivesocket/exceptions/Exceptions.java b/src/main/java/io/reactivesocket/exceptions/Exceptions.java deleted file mode 100644 index 80d280675..000000000 --- a/src/main/java/io/reactivesocket/exceptions/Exceptions.java +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.exceptions; - -import io.reactivesocket.Frame; - -import java.nio.ByteBuffer; - -import static io.reactivesocket.internal.frame.ErrorFrameFlyweight.*; -import static java.nio.charset.StandardCharsets.UTF_8; - -public class Exceptions { - - private Exceptions() {} - - public static Throwable from(Frame frame) { - final int errorCode = Frame.Error.errorCode(frame); - String message = ""; - final ByteBuffer byteBuffer = frame.getMetadata(); - if (byteBuffer.hasArray()) { - message = new String(byteBuffer.array(), UTF_8); - } - - Throwable ex; - switch (errorCode) { - case APPLICATION_ERROR: - ex = new ApplicationException(message); - break; - case CONNECTION_ERROR: - ex = new ConnectionException(message); - break; - case INVALID: - ex = new InvalidRequestException(message); - break; - case INVALID_SETUP: - ex = new InvalidSetupException(message); - break; - case REJECTED: - ex = new RejectedException(message); - break; - case REJECTED_SETUP: - ex = new RejectedSetupException(message); - break; - case UNSUPPORTED_SETUP: - ex = new UnsupportedSetupException(message); - break; - default: - ex = new InvalidRequestException("Invalid Error frame"); - } - return ex; - } -} diff --git a/src/main/java/io/reactivesocket/exceptions/InvalidRequestException.java b/src/main/java/io/reactivesocket/exceptions/InvalidRequestException.java deleted file mode 100644 index e915b0413..000000000 --- a/src/main/java/io/reactivesocket/exceptions/InvalidRequestException.java +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.exceptions; - -public class InvalidRequestException extends Throwable { - public InvalidRequestException(String message) { - super(message); - } - - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } -} diff --git a/src/main/java/io/reactivesocket/exceptions/RejectedException.java b/src/main/java/io/reactivesocket/exceptions/RejectedException.java deleted file mode 100644 index 460314b54..000000000 --- a/src/main/java/io/reactivesocket/exceptions/RejectedException.java +++ /dev/null @@ -1,27 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.exceptions; - -public class RejectedException extends Throwable implements Retryable { - public RejectedException (String message) { - super(message); - } - - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } -} diff --git a/src/main/java/io/reactivesocket/exceptions/SetupException.java b/src/main/java/io/reactivesocket/exceptions/SetupException.java deleted file mode 100644 index edf514771..000000000 --- a/src/main/java/io/reactivesocket/exceptions/SetupException.java +++ /dev/null @@ -1,30 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.exceptions; - -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; - -public class SetupException extends Throwable { - public SetupException(String message) { - super(message); - } - - @Override - public synchronized Throwable fillInStackTrace() { - return this; - } -} diff --git a/src/main/java/io/reactivesocket/internal/FragmentedPublisher.java b/src/main/java/io/reactivesocket/internal/FragmentedPublisher.java deleted file mode 100644 index 4a4f21c6f..000000000 --- a/src/main/java/io/reactivesocket/internal/FragmentedPublisher.java +++ /dev/null @@ -1,57 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.Payload; -import io.reactivesocket.internal.frame.PayloadFragmenter; - -public class FragmentedPublisher implements Publisher { - - private final PayloadFragmenter fragmenter = new PayloadFragmenter(Frame.METADATA_MTU, Frame.DATA_MTU); - private final Publisher responsePublisher; - private final int streamId; - private final FrameType type; - - public FragmentedPublisher(FrameType type, int streamId, Publisher responsePublisher) { - this.type = type; - this.streamId = streamId; - this.responsePublisher = responsePublisher; - } - - @Override - public void subscribe(Subscriber child) { - child.onSubscribe(new Subscription() { - - @Override - public void request(long n) { - // TODO Auto-generated method stub - - } - - @Override - public void cancel() { - // TODO Auto-generated method stub - - }}); - } - -} diff --git a/src/main/java/io/reactivesocket/internal/PublisherUtils.java b/src/main/java/io/reactivesocket/internal/PublisherUtils.java deleted file mode 100644 index 8c1a6a9c7..000000000 --- a/src/main/java/io/reactivesocket/internal/PublisherUtils.java +++ /dev/null @@ -1,330 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import java.nio.ByteBuffer; -import java.util.Iterator; -import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; - -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import io.reactivesocket.Frame; -import io.reactivesocket.Payload; -import io.reactivesocket.internal.rx.BackpressureHelper; -import io.reactivesocket.internal.rx.BackpressureUtils; -import io.reactivesocket.internal.rx.EmptySubscription; -import io.reactivesocket.internal.rx.SubscriptionHelper; - -public class PublisherUtils { - - private PublisherUtils() {} - - // TODO: be better about using scheduler for this - public static final ScheduledExecutorService SCHEDULER_THREAD = Executors.newScheduledThreadPool(1, - (r) -> { - final Thread thread = new Thread(r); - - thread.setDaemon(true); - - return thread; - }); - - public static final Publisher errorFrame(int streamId, Throwable e) { - return (Subscriber s) -> { - s.onSubscribe(new Subscription() { - - @Override - public void request(long n) { - if (n > 0) { - s.onNext(Frame.Error.from(streamId, e)); - s.onComplete(); - } - } - - @Override - public void cancel() { - // ignoring as nothing to do - } - - }); - - }; - } - - private final static ByteBuffer EMPTY_BYTES = ByteBuffer.allocate(0); - - public static final Publisher errorPayload(Throwable e) { - return (Subscriber s) -> { - s.onSubscribe(new Subscription() { - - @Override - public void request(long n) { - if (n > 0) { - Payload errorPayload = new Payload() { - - @Override - public ByteBuffer getData() { - final byte[] bytes = e.getMessage().getBytes(); - final ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); - return byteBuffer; - } - - @Override - public ByteBuffer getMetadata() { - return EMPTY_BYTES; - } - - }; - s.onNext(errorPayload); - s.onComplete(); - } - } - - @Override - public void cancel() { - // ignoring as nothing to do - } - - }); - - }; - } - - public static final Publisher errorVoid(Throwable e) { - return (Subscriber s) -> { - s.onSubscribe(new Subscription() { - - @Override - public void request(long n) { - } - - @Override - public void cancel() { - // ignoring as nothing to do - } - - }); - s.onError(e); - - }; - } - - public static final Publisher just(Frame frame) { - return (Subscriber s) -> { - s.onSubscribe(new Subscription() { - - boolean completed = false; - - @Override - public void request(long n) { - if (!completed && n > 0) { - completed = true; - s.onNext(frame); - s.onComplete(); - } - } - - @Override - public void cancel() { - // ignoring as nothing to do - } - - }); - - }; - } - - public static final Publisher empty() { - return (Subscriber s) -> { - s.onSubscribe(new Subscription() { - - @Override - public void request(long n) { - } - - @Override - public void cancel() { - // ignoring as nothing to do - } - - }); - s.onComplete(); // TODO confirm this is okay with ReactiveStream spec to send immediately after onSubscribe (I think so since no data is being sent so requestN doesn't matter) - }; - - } - - public static final Publisher keepaliveTicker(final int interval, final TimeUnit timeUnit) { - return (Subscriber s) -> { - s.onSubscribe(new Subscription() - { - final AtomicLong requested = new AtomicLong(0); - final AtomicBoolean started = new AtomicBoolean(false); - volatile ScheduledFuture ticker; - - public void request(long n) - { - BackpressureUtils.getAndAddRequest(requested, n); - if (started.compareAndSet(false, true)) - { - ticker = SCHEDULER_THREAD.scheduleWithFixedDelay(() -> { - final long value = requested.getAndDecrement(); - - if (0 < value) - { - s.onNext(Frame.Keepalive.from(Frame.NULL_BYTEBUFFER, true)); - } - else - { - requested.getAndIncrement(); - } - }, interval, interval, timeUnit); - } - } - - public void cancel() - { - // only used internally and so should not be called before request is done. Race condition exists! - if (null != ticker) - { - ticker.cancel(true); - } - } - }); - }; - } - - public static final Publisher fromIterable(Iterable is) { - return new PublisherIterableSource<>(is); - } - - public static final class PublisherIterableSource extends AtomicBoolean implements Publisher { - /** */ - private static final long serialVersionUID = 9051303031779816842L; - - final Iterable source; - public PublisherIterableSource(Iterable source) { - this.source = source; - } - - @Override - public void subscribe(Subscriber s) { - Iterator it; - try { - it = source.iterator(); - } catch (Throwable e) { - EmptySubscription.error(e, s); - return; - } - boolean hasNext; - try { - hasNext = it.hasNext(); - } catch (Throwable e) { - EmptySubscription.error(e, s); - return; - } - if (!hasNext) { - EmptySubscription.complete(s); - return; - } - s.onSubscribe(new IteratorSourceSubscription<>(it, s)); - } - - static final class IteratorSourceSubscription extends AtomicLong implements Subscription { - /** */ - private static final long serialVersionUID = 8931425802102883003L; - final Iterator it; - final Subscriber subscriber; - - volatile boolean cancelled; - - public IteratorSourceSubscription(Iterator it, Subscriber subscriber) { - this.it = it; - this.subscriber = subscriber; - } - @Override - public void request(long n) { - if (SubscriptionHelper.validateRequest(n)) { - return; - } - if (BackpressureHelper.add(this, n) != 0L) { - return; - } - long r = n; - long r0 = n; - final Subscriber subscriber = this.subscriber; - final Iterator it = this.it; - for (;;) { - if (cancelled) { - return; - } - - long e = 0L; - while (r != 0L) { - T v; - try { - v = it.next(); - } catch (Throwable ex) { - subscriber.onError(ex); - return; - } - - if (v == null) { - subscriber.onError(new NullPointerException("Iterator returned a null element")); - return; - } - - subscriber.onNext(v); - - if (cancelled) { - return; - } - - boolean hasNext; - try { - hasNext = it.hasNext(); - } catch (Throwable ex) { - subscriber.onError(ex); - return; - } - if (!hasNext) { - subscriber.onComplete(); - return; - } - - r--; - e--; - } - if (e != 0L && r0 != Long.MAX_VALUE) { - r = addAndGet(e); - } - if (r == 0L) { - break; - } - } - } - @Override - public void cancel() { - cancelled = true; - } - } - } - - -} diff --git a/src/main/java/io/reactivesocket/internal/Requester.java b/src/main/java/io/reactivesocket/internal/Requester.java deleted file mode 100644 index 586fff298..000000000 --- a/src/main/java/io/reactivesocket/internal/Requester.java +++ /dev/null @@ -1,988 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.charset.Charset; -import java.util.Collection; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; - -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import io.reactivesocket.ConnectionSetupPayload; -import io.reactivesocket.DuplexConnection; -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.Payload; -import io.reactivesocket.exceptions.CancelException; -import io.reactivesocket.exceptions.Exceptions; -import io.reactivesocket.exceptions.Retryable; -import io.reactivesocket.internal.frame.RequestFrameFlyweight; -import io.reactivesocket.internal.rx.BackpressureUtils; -import io.reactivesocket.internal.rx.EmptyDisposable; -import io.reactivesocket.internal.rx.EmptySubscription; -import io.reactivesocket.rx.Completable; -import io.reactivesocket.rx.Disposable; -import io.reactivesocket.rx.Observer; -import org.agrona.collections.Int2ObjectHashMap; - -/** - * Protocol implementation abstracted over a {@link DuplexConnection}. - *

- * Concrete implementations of {@link DuplexConnection} over TCP, WebSockets, Aeron, etc can be passed to this class for protocol handling. - */ -public class Requester { - - private final static Disposable CANCELLED = new EmptyDisposable(); - private final static int KEEPALIVE_INTERVAL_MS = 1000; - - private final boolean isServer; - private final DuplexConnection connection; - private final Int2ObjectHashMap> streamInputMap = new Int2ObjectHashMap<>(); - private final ConnectionSetupPayload setupPayload; - private final Consumer errorStream; - - private final boolean honorLease; - private long ttlExpiration; - private long numberOfRemainingRequests = 0; - private long timeOfLastKeepalive = 0; - private int streamCount = 0; // 0 is reserved for setup, all normal messages are >= 1 - - private static final long DEFAULT_BATCH = 1024; - private static final long REQUEST_THRESHOLD = 256; - - private volatile boolean requesterStarted = false; - - private Requester( - boolean isServer, - DuplexConnection connection, - ConnectionSetupPayload setupPayload, - Consumer errorStream - ) { - this.isServer = isServer; - this.connection = connection; - this.setupPayload = setupPayload; - this.errorStream = errorStream; - if (isServer) { - streamCount = 1; // server is odds - } else { - streamCount = 0; // client is even - } - - this.honorLease = setupPayload.willClientHonorLease(); - } - - public static Requester createClientRequester( - DuplexConnection connection, - ConnectionSetupPayload setupPayload, - Consumer errorStream, - Completable requesterCompletable - ) { - Requester requester = new Requester(false, connection, setupPayload, errorStream); - requester.start(requesterCompletable); - return requester; - } - - public static Requester createServerRequester( - DuplexConnection connection, - ConnectionSetupPayload setupPayload, - Consumer errorStream, - Completable requesterCompletable - ) { - Requester requester = new Requester(true, connection, setupPayload, errorStream); - requester.start(requesterCompletable); - return requester; - } - - public void shutdown() { - // TODO do something here - System.err.println("**** Requester.shutdown => this should actually do something"); - } - - public boolean isServer() { - return isServer; - } - - public long timeOfLastKeepalive() - { - return timeOfLastKeepalive; - } - - /** - * Request/Response with a single message response. - * - * @param payload - * @return - */ - public Publisher requestResponse(final Payload payload) { - return startRequestResponse(nextStreamId(), FrameType.REQUEST_RESPONSE, payload); - } - - /** - * Request/Stream with a finite multi-message response followed by a - * terminal state {@link Subscriber#onComplete()} or - * {@link Subscriber#onError(Throwable)}. - * - * @param payload - * @return - */ - public Publisher requestStream(final Payload payload) { - return startStream(nextStreamId(), FrameType.REQUEST_STREAM, payload); - } - - /** - * Fire-and-forget without a response from the server. - *

- * The returned {@link Publisher} will emit {@link Subscriber#onComplete()} - * or {@link Subscriber#onError(Throwable)} to represent success or failure - * in sending from the client side, but no feedback from the server will - * be returned. - * - * @param payload - * @return - */ - public Publisher fireAndForget(final Payload payload) { - if (payload == null) { - throw new IllegalStateException("Payload can not be null"); - } - assertStarted(); - return child -> child.onSubscribe(new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - - @Override - public void request(long n) { - if (n > 0 && started.compareAndSet(false, true)) { - numberOfRemainingRequests--; - - Frame fnfFrame = Frame.Request.from( - nextStreamId(), FrameType.FIRE_AND_FORGET, payload, 0); - connection.addOutput(fnfFrame, new Completable() { - @Override - public void success() { - child.onComplete(); - } - - @Override - public void error(Throwable e) { - child.onError(e); - } - }); - } - } - - @Override - public void cancel() { - // nothing to cancel on a fire-and-forget - } - }); - } - - /** - * Send asynchonrous Metadata Push without a response from the server. - *

- * The returned {@link Publisher} will emit {@link Subscriber#onComplete()} - * or {@link Subscriber#onError(Throwable)} to represent success or failure - * in sending from the client side, but no feedback from the server will be - * returned. - * - * @param payload - * @return - */ - public Publisher metadataPush(final Payload payload) { - if (payload == null) { - throw new IllegalArgumentException("Payload can not be null"); - } - assertStarted(); - return (Subscriber child) -> - child.onSubscribe(new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - - @Override - public void request(long n) { - if (n > 0 && started.compareAndSet(false, true)) { - numberOfRemainingRequests--; - - Frame metadataPush = Frame.Request.from( - nextStreamId(), FrameType.METADATA_PUSH, payload, 0); - connection.addOutput(metadataPush, new Completable() { - @Override - public void success() { - child.onComplete(); - } - - @Override - public void error(Throwable e) { - child.onError(e); - } - }); - } - } - - @Override - public void cancel() { - // nothing to cancel on a metadataPush - } - }); - } - - - /** - * Event subscription with an infinite multi-message response potentially - * terminated with an {@link Subscriber#onError(Throwable)}. - * - * @param payload - * @return - */ - public Publisher requestSubscription(final Payload payload) { - return startStream(nextStreamId(), FrameType.REQUEST_SUBSCRIPTION, payload); - } - - /** - * Request/Stream with a finite multi-message response followed by a - * terminal state {@link Subscriber#onComplete()} or - * {@link Subscriber#onError(Throwable)}. - * - * @param payloadStream - * @return - */ - public Publisher requestChannel(final Publisher payloadStream) { - return startChannel(nextStreamId(), FrameType.REQUEST_CHANNEL, payloadStream); - } - - private void assertStarted() { - if (!requesterStarted) { - throw new IllegalStateException("Requester not initialized. " + - "Please await 'start()' completion before submitting requests."); - } - } - - - /** - * Return availability of sending requests - * - * @return - */ - public double availability() { - if (!honorLease) { - return 1.0; - } - final long now = System.currentTimeMillis(); - double available = 0.0; - if (numberOfRemainingRequests > 0 && (now < ttlExpiration)) { - available = 1.0; - } - return available; - } - - /* - * Using payload/payloads with null check for efficiency so I don't have to - * allocate a Publisher for the most common case of single Payload - */ - private Publisher startStream(int streamId, FrameType type, Payload payload) { - assertStarted(); - return (Subscriber child) -> { - child.onSubscribe(new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - volatile StreamInputSubscriber streamInputSubscriber; - volatile UnicastSubject writer; - // TODO does this need to be atomic? Can request(n) come from any thread? - final AtomicLong requested = new AtomicLong(); - // TODO AtomicLong just so I can pass it around ... perf issue? or is there a thread-safety issue? - final AtomicLong outstanding = new AtomicLong(); - - @Override - public void request(long n) { - if(n <= 0) { - return; - } - BackpressureUtils.getAndAddRequest(requested, n); - if (started.compareAndSet(false, true)) { - // determine initial RequestN - long currentN = requested.get(); - long requestN = currentN < DEFAULT_BATCH ? currentN : DEFAULT_BATCH; - long threshold = - requestN == DEFAULT_BATCH ? REQUEST_THRESHOLD : requestN / 3; - - // declare output to transport - writer = UnicastSubject.create((w, rn) -> { - numberOfRemainingRequests--; - - // decrement as we request it - requested.addAndGet(-requestN); - // record how many we have requested - outstanding.addAndGet(requestN); - - // when transport connects we write the request frame for this stream - w.onNext(Frame.Request.from(streamId, type, payload, (int)requestN)); - }); - - // Response frames for this Stream - UnicastSubject transportInputSubject = UnicastSubject.create(); - synchronized(Requester.this) { - streamInputMap.put(streamId, transportInputSubject); - } - streamInputSubscriber = new StreamInputSubscriber( - streamId, - threshold, - outstanding, - requested, - writer, - child, - this::cancel - ); - transportInputSubject.subscribe(streamInputSubscriber); - - // connect to transport - connection.addOutput(writer, new Completable() { - @Override - public void success() { - // nothing to do onSuccess - } - - @Override - public void error(Throwable e) { - child.onError(e); - cancel(); - } - }); - } else { - // propagate further requestN frames - long currentN = requested.get(); - long requestThreshold = - REQUEST_THRESHOLD < currentN ? REQUEST_THRESHOLD : currentN / 3; - requestIfNecessary( - streamId, - requestThreshold, - currentN, - outstanding.get(), - writer, - requested, - outstanding - ); - } - - } - - @Override - public void cancel() { - synchronized(Requester.this) { - streamInputMap.remove(streamId); - } - if (!streamInputSubscriber.terminated.get()) { - writer.onNext(Frame.Cancel.from(streamId)); - } - streamInputSubscriber.parentSubscription.cancel(); - } - - }); - }; - } - - /* - * Using payload/payloads with null check for efficiency so I don't have to - * allocate a Publisher for the most common case of single Payload - */ - private Publisher startChannel( - int streamId, - FrameType type, - Publisher payloads - ) { - if (payloads == null) { - throw new IllegalStateException("Both payload and payloads can not be null"); - } - assertStarted(); - return (Subscriber child) -> { - child.onSubscribe(new Subscription() { - - AtomicBoolean started = new AtomicBoolean(false); - volatile StreamInputSubscriber streamInputSubscriber; - volatile UnicastSubject writer; - final AtomicReference payloadsSubscription = new AtomicReference<>(); - // TODO does this need to be atomic? Can request(n) come from any thread? - final AtomicLong requested = new AtomicLong(); - // TODO AtomicLong just so I can pass it around ... perf issue? or is there a thread-safety issue? - final AtomicLong outstanding = new AtomicLong(); - - @Override - public void request(long n) { - if(n <= 0) { - return; - } - BackpressureUtils.getAndAddRequest(requested, n); - if (started.compareAndSet(false, true)) { - // determine initial RequestN - long currentN = requested.get(); - final long requestN = currentN < DEFAULT_BATCH ? currentN : DEFAULT_BATCH; - // threshold - final long threshold = - requestN == DEFAULT_BATCH ? REQUEST_THRESHOLD : requestN / 3; - - // declare output to transport - writer = UnicastSubject.create((w, rn) -> { - numberOfRemainingRequests--; - // decrement as we request it - requested.addAndGet(-requestN); - // record how many we have requested - outstanding.addAndGet(requestN); - - connection.addOutput(new Publisher() { - @Override - public void subscribe(Subscriber transport) { - transport.onSubscribe(new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - @Override - public void request(long n) { - if(n <= 0) { - return; - } - if(started.compareAndSet(false, true)) { - payloads.subscribe(new Subscriber() { - - @Override - public void onSubscribe(Subscription s) { - if (!payloadsSubscription.compareAndSet(null, s)) { - // we are already unsubscribed - s.cancel(); - } else { - // we always start with 1 to initiate - // requestChannel, then wait for REQUEST_N - // from Responder to send more - s.request(1); - } - } - - // onNext is serialized by contract so this is - // okay as non-volatile primitive - boolean isInitialRequest = true; - - @Override - public void onNext(Payload p) { - if(isInitialRequest) { - isInitialRequest = false; - Frame f = Frame.Request.from( - streamId, type, p, (int)requestN); - transport.onNext(f); - } else { - Frame f = Frame.Request.from( - streamId, type, p, 0); - transport.onNext(f); - } - } - - @Override - public void onError(Throwable t) { - // TODO validate with unit tests - RuntimeException exc = new RuntimeException( - "Error received from request stream.", t); - transport.onError(exc); - child.onError(exc); - cancel(); - } - - @Override - public void onComplete() { - Frame f = Frame.Request.from( - streamId, - FrameType.REQUEST_CHANNEL, - RequestFrameFlyweight.FLAGS_REQUEST_CHANNEL_C - ); - transport.onNext(f); - transport.onComplete(); - } - - }); - } else { - // TODO we need to compose this requestN from - // transport with the remote REQUEST_N - } - - } - - @Override - public void cancel() {} - }); - } - }, new Completable() { - @Override - public void success() { - // nothing to do onSuccess - } - - @Override - public void error(Throwable e) { - child.onError(e); - cancel(); - } - }); - - }); - - // Response frames for this Stream - UnicastSubject transportInputSubject = UnicastSubject.create(); - synchronized(Requester.this) { - streamInputMap.put(streamId, transportInputSubject); - } - streamInputSubscriber = new StreamInputSubscriber( - streamId, - threshold, - outstanding, - requested, - writer, - child, - payloadsSubscription, - this::cancel - ); - transportInputSubject.subscribe(streamInputSubscriber); - - // connect to transport - connection.addOutput(writer, new Completable() { - @Override - public void success() { - // nothing to do onSuccess - } - - @Override - public void error(Throwable e) { - child.onError(e); - if (!(e instanceof Retryable)) { - cancel(); - } - } - }); - } else { - // propagate further requestN frames - long currentN = requested.get(); - long requestThreshold = - REQUEST_THRESHOLD < currentN ? REQUEST_THRESHOLD : currentN / 3; - requestIfNecessary( - streamId, - requestThreshold, - currentN, - outstanding.get(), - writer, - requested, - outstanding - ); - } - } - - @Override - public void cancel() { - synchronized(Requester.this) { - streamInputMap.remove(streamId); - } - if (!streamInputSubscriber.terminated.get()) { - writer.onNext(Frame.Cancel.from(streamId)); - } - streamInputSubscriber.parentSubscription.cancel(); - if (payloadsSubscription != null) { - if (!payloadsSubscription.compareAndSet(null, EmptySubscription.INSTANCE)) { - // unsubscribe it if it already exists - payloadsSubscription.get().cancel(); - } - } - } - - }); - }; - } - - /* - * Special-cased for performance reasons (achieved 20-30% throughput - * increase over using startStream for request/response) - */ - private Publisher startRequestResponse(int streamId, FrameType type, Payload payload) { - if (payload == null) { - throw new IllegalStateException("Both payload and payloads can not be null"); - } - assertStarted(); - return (Subscriber child) -> { - child.onSubscribe(new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - volatile StreamInputSubscriber streamInputSubscriber; - volatile UnicastSubject writer; - - @Override - public void request(long n) { - if (n > 0 && started.compareAndSet(false, true)) { - // Response frames for this Stream - UnicastSubject transportInputSubject = UnicastSubject.create(); - synchronized(Requester.this) { - streamInputMap.put(streamId, transportInputSubject); - } - streamInputSubscriber = new StreamInputSubscriber( - streamId, - 0, - null, - null, - writer, - child, - this::cancel - ); - transportInputSubject.subscribe(streamInputSubscriber); - - Frame requestFrame = Frame.Request.from(streamId, type, payload, 1); - // connect to transport - connection.addOutput(requestFrame, new Completable() { - @Override - public void success() { - // nothing to do onSuccess - } - - @Override - public void error(Throwable e) { - child.onError(e); - cancel(); - } - }); - } - } - - @Override - public void cancel() { - if (!streamInputSubscriber.terminated.get()) { - Frame cancelFrame = Frame.Cancel.from(streamId); - connection.addOutput(cancelFrame, new Completable() { - @Override - public void success() { - // nothing to do onSuccess - } - - @Override - public void error(Throwable e) { - child.onError(e); - } - }); - } - synchronized(Requester.this) { - streamInputMap.remove(streamId); - } - streamInputSubscriber.parentSubscription.cancel(); - } - }); - }; - } - - private final static class StreamInputSubscriber implements Subscriber { - final AtomicBoolean terminated = new AtomicBoolean(false); - volatile Subscription parentSubscription; - - private final int streamId; - private final long requestThreshold; - private final AtomicLong outstandingRequests; - private final AtomicLong requested; - private final UnicastSubject writer; - private final Subscriber child; - private final Runnable cancelAction; - private final AtomicReference requestStreamSubscription; - - public StreamInputSubscriber( - int streamId, - long threshold, - AtomicLong outstanding, - AtomicLong requested, - UnicastSubject writer, - Subscriber child, - Runnable cancelAction - ) { - this.streamId = streamId; - this.requestThreshold = threshold; - this.requested = requested; - this.outstandingRequests = outstanding; - this.writer = writer; - this.child = child; - this.cancelAction = cancelAction; - this.requestStreamSubscription = null; - } - - public StreamInputSubscriber( - int streamId, - long threshold, - AtomicLong outstanding, - AtomicLong requested, - UnicastSubject writer, - Subscriber child, - AtomicReference requestStreamSubscription, - Runnable cancelAction - ) { - this.streamId = streamId; - this.requestThreshold = threshold; - this.requested = requested; - this.outstandingRequests = outstanding; - this.writer = writer; - this.child = child; - this.cancelAction = cancelAction; - this.requestStreamSubscription = requestStreamSubscription; - } - - @Override - public void onSubscribe(Subscription s) { - this.parentSubscription = s; - // no backpressure to transport (we will only receive what we've asked for already) - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Frame frame) { - FrameType type = frame.getType(); - // convert ERROR messages into terminal events - if (type == FrameType.NEXT_COMPLETE) { - terminated.set(true); - child.onNext(frame); - onComplete(); - cancel(); - } else if (type == FrameType.NEXT) { - child.onNext(frame); - long currentOutstanding = outstandingRequests.decrementAndGet(); - requestIfNecessary(streamId, requestThreshold, requested.get(), - currentOutstanding, writer, requested, outstandingRequests); - } else if (type == FrameType.REQUEST_N) { - if(requestStreamSubscription != null) { - Subscription s = requestStreamSubscription.get(); - if(s != null) { - s.request(Frame.RequestN.requestN(frame)); - } else { - // TODO can this ever be null? - System.err.println( - "ReactiveSocket Requester DEBUG: requestStreamSubscription is null"); - } - return; - } - // TODO should we do anything if we don't find the stream? emitting an error - // is risky as the responder could have terminated and cleaned up already - } else if (type == FrameType.COMPLETE) { - terminated.set(true); - onComplete(); - cancel(); - } else if (type == FrameType.ERROR) { - terminated.set(true); - final ByteBuffer byteBuffer = frame.getData(); - String errorMessage = getByteBufferAsString(byteBuffer); - onError(new RuntimeException(errorMessage)); - cancel(); - } else { - onError(new RuntimeException("Unexpected FrameType: " + frame.getType())); - cancel(); - } - } - - @Override - public void onError(Throwable t) { - terminated.set(true); - child.onError(t); - } - - @Override - public void onComplete() { - terminated.set(true); - child.onComplete(); - } - - private void cancel() { - cancelAction.run(); - } - } - - private static void requestIfNecessary( - int streamId, - long requestThreshold, - long currentN, - long currentOutstanding, - UnicastSubject writer, - AtomicLong requested, - AtomicLong outstanding - ) { - if(currentOutstanding <= requestThreshold) { - long batchSize = DEFAULT_BATCH - currentOutstanding; - final long requestN = currentN < batchSize ? currentN : batchSize; - - if (requestN > 0) { - // decrement as we request it - requested.addAndGet(-requestN); - // record how many we have requested - outstanding.addAndGet(requestN); - - writer.onNext(Frame.RequestN.from(streamId, (int)requestN)); - } - } - } - - private int nextStreamId() { - return streamCount += 2; // go by two since server is odd, client is even - } - - private void start(Completable onComplete) { - AtomicReference connectionSubscription = new AtomicReference<>(); - // get input from responder->requestor for responses - connection.getInput().subscribe(new Observer() { - public void onSubscribe(Disposable d) { - if (connectionSubscription.compareAndSet(null, d)) { - if(isServer) { - requesterStarted = true; - onComplete.success(); - } else { - // now that we are connected, send SETUP frame - // (asynchronously, other messages can continue being written after this) - Frame setupFrame = Frame.Setup.from( - setupPayload.getFlags(), - KEEPALIVE_INTERVAL_MS, - 0, - setupPayload.metadataMimeType(), - setupPayload.dataMimeType(), - setupPayload - ); - connection.addOutput(setupFrame, - new Completable() { - @Override - public void success() { - requesterStarted = true; - onComplete.success(); - } - - @Override - public void error(Throwable e) { - onComplete.error(e); - tearDown(e); - } - }); - - Publisher keepaliveTicker = - PublisherUtils.keepaliveTicker(KEEPALIVE_INTERVAL_MS, TimeUnit.MILLISECONDS); - connection.addOutput(keepaliveTicker, - new Completable() { - public void success() {} - - public void error(Throwable e) { - onComplete.error(e); - tearDown(e); - } - } - ); - } - } else { - // means we already were cancelled - d.dispose(); - onComplete.error(new CancelException("Connection Is Already Cancelled")); - } - } - - private void tearDown(Throwable e) { - onError(e); - } - - public void onNext(Frame frame) { - int streamId = frame.getStreamId(); - if (streamId == 0) { - if (FrameType.ERROR.equals(frame.getType())) { - final Throwable throwable = Exceptions.from(frame); - onError(throwable); - } else if (FrameType.LEASE.equals(frame.getType()) && honorLease) { - numberOfRemainingRequests = Frame.Lease.numberOfRequests(frame); - final long now = System.currentTimeMillis(); - final int ttl = Frame.Lease.ttl(frame); - if (ttl == Integer.MAX_VALUE) { - // Integer.MAX_VALUE represents infinity - ttlExpiration = Long.MAX_VALUE; - } else { - ttlExpiration = now + ttl; - } - } else if (FrameType.KEEPALIVE.equals(frame.getType())) { - timeOfLastKeepalive = System.currentTimeMillis(); - } else { - onError(new RuntimeException( - "Received unexpected message type on stream 0: " + frame.getType().name())); - } - } else { - UnicastSubject streamSubject; - synchronized (Requester.this) { - streamSubject = streamInputMap.get(streamId); - } - if (streamSubject == null) { - if (streamId <= streamCount) { - // receiving a frame after a given stream has been cancelled/completed, - // so ignore (cancellation is async so there is a race condition) - return; - } else { - // message for stream that has never existed, we have a problem with - // the overall connection and must tear down - if (frame.getType() == FrameType.ERROR) { - String errorMessage = getByteBufferAsString(frame.getData()); - onError(new RuntimeException( - "Received error for non-existent stream: " - + streamId + " Message: " + errorMessage)); - } else { - onError(new RuntimeException( - "Received message for non-existent stream: " + streamId)); - } - } - } else { - streamSubject.onNext(frame); - } - } - } - - public void onError(Throwable t) { - Collection> subjects; - synchronized (Requester.this) { - subjects = streamInputMap.values(); - } - subjects.forEach(subject -> subject.onError(t)); - // TODO: iterate over responder side and destroy world - errorStream.accept(t); - cancel(); - } - - public void onComplete() { - Collection> subjects; - synchronized (Requester.this) { - subjects = streamInputMap.values(); - } - subjects.forEach(UnicastSubject::onComplete); - cancel(); - } - - public void cancel() { // TODO this isn't used ... is it supposed to be? - if (!connectionSubscription.compareAndSet(null, CANCELLED)) { - // cancel the one that was there if we failed to set the sentinel - connectionSubscription.get().dispose(); - try { - connection.close(); - } catch (IOException e) { - errorStream.accept(e); - } - } - } - }); - } - - private static String getByteBufferAsString(ByteBuffer bb) { - final byte[] bytes = new byte[bb.capacity()]; - bb.get(bytes); - return new String(bytes, Charset.forName("UTF-8")); - } -} diff --git a/src/main/java/io/reactivesocket/internal/Responder.java b/src/main/java/io/reactivesocket/internal/Responder.java deleted file mode 100644 index bfcc62160..000000000 --- a/src/main/java/io/reactivesocket/internal/Responder.java +++ /dev/null @@ -1,892 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import io.reactivesocket.ConnectionSetupHandler; -import io.reactivesocket.ConnectionSetupPayload; -import io.reactivesocket.DuplexConnection; -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.LeaseGovernor; -import io.reactivesocket.Payload; -import io.reactivesocket.ReactiveSocket; -import io.reactivesocket.RequestHandler; -import io.reactivesocket.exceptions.InvalidSetupException; -import io.reactivesocket.exceptions.RejectedException; -import io.reactivesocket.exceptions.SetupException; -import io.reactivesocket.internal.frame.FrameHeaderFlyweight; -import io.reactivesocket.internal.frame.SetupFrameFlyweight; -import io.reactivesocket.internal.rx.EmptyDisposable; -import io.reactivesocket.internal.rx.EmptySubscription; -import io.reactivesocket.rx.Completable; -import io.reactivesocket.rx.Disposable; -import io.reactivesocket.rx.Observer; -import org.agrona.collections.Int2ObjectHashMap; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.BiFunction; -import java.util.function.Consumer; - -/** - * Protocol implementation abstracted over a {@link DuplexConnection}. - *

- * Concrete implementations of {@link DuplexConnection} over TCP, WebSockets, - * Aeron, etc can be passed to this class for protocol handling. The request - * handlers passed in at creation will be invoked - * for each request over the connection. - */ -public class Responder { - private final DuplexConnection connection; - private final ConnectionSetupHandler connectionHandler; // for server - private final RequestHandler clientRequestHandler; // for client - private final Consumer errorStream; - private volatile LeaseGovernor leaseGovernor; - private long timeOfLastKeepalive; - private final Consumer setupCallback; - private final boolean isServer; - - private Responder( - boolean isServer, - DuplexConnection connection, - ConnectionSetupHandler connectionHandler, - RequestHandler requestHandler, - LeaseGovernor leaseGovernor, - Consumer errorStream, - Consumer setupCallback - ) { - this.isServer = isServer; - this.connection = connection; - this.connectionHandler = connectionHandler; - this.clientRequestHandler = requestHandler; - this.leaseGovernor = leaseGovernor; - this.errorStream = errorStream; - this.timeOfLastKeepalive = System.nanoTime(); - this.setupCallback = setupCallback; - } - - /** - * @param connectionHandler Handle connection setup and set up request - * handling. - * @param errorStream A {@link Consumer} which will receive - * all errors that occurs processing requests. - * This include fireAndForget which ONLY emit errors - * server-side via this mechanism. - * @return responder instance - */ - public static Responder createServerResponder( - DuplexConnection connection, - ConnectionSetupHandler connectionHandler, - LeaseGovernor leaseGovernor, - Consumer errorStream, - Completable responderCompletable, - Consumer setupCallback, - ReactiveSocket reactiveSocket - ) { - Responder responder = new Responder(true, connection, connectionHandler, null, - leaseGovernor, errorStream, setupCallback); - responder.start(responderCompletable, reactiveSocket); - return responder; - } - - public static Responder createServerResponder( - DuplexConnection connection, - ConnectionSetupHandler connectionHandler, - LeaseGovernor leaseGovernor, - Consumer errorStream, - Completable responderCompletable, - ReactiveSocket reactiveSocket - ) { - return createServerResponder(connection, connectionHandler, leaseGovernor, - errorStream, responderCompletable, s -> {}, reactiveSocket); - } - - public static Responder createClientResponder( - DuplexConnection connection, - RequestHandler requestHandler, - LeaseGovernor leaseGovernor, - Consumer errorStream, - Completable responderCompletable, - ReactiveSocket reactiveSocket - ) { - Responder responder = new Responder(false, connection, null, requestHandler, - leaseGovernor, errorStream, s -> {}); - responder.start(responderCompletable, reactiveSocket); - return responder; - } - - /** - * Send a LEASE frame immediately. Only way a LEASE is sent. Handled - * entirely by application logic. - * - * @param ttl of lease - * @param numberOfRequests of lease - */ - public void sendLease(final int ttl, final int numberOfRequests) { - Frame leaseFrame = Frame.Lease.from(ttl, numberOfRequests, Frame.NULL_BYTEBUFFER); - connection.addOutput(PublisherUtils.just(leaseFrame), new Completable() { - @Override - public void success() {} - - @Override - public void error(Throwable e) { - errorStream.accept(new RuntimeException("could not send lease ", e)); - } - }); - } - - /** - * Return time of last keepalive from client - * - * @return time from {@link System#nanoTime()} of last keepalive - */ - public long timeOfLastKeepalive() { - return timeOfLastKeepalive; - } - - private void start(final Completable responderCompletable, ReactiveSocket reactiveSocket) { - /* state of cancellation subjects during connection */ - final Int2ObjectHashMap cancellationSubscriptions = new Int2ObjectHashMap<>(); - /* streams in flight that can receive REQUEST_N messages */ - final Int2ObjectHashMap inFlight = new Int2ObjectHashMap<>(); - /* bidirectional channels */ - // TODO: should/can we make this optional so that it only gets allocated per connection if - // channels are used? - final Int2ObjectHashMap> channels = new Int2ObjectHashMap<>(); - - final AtomicBoolean childTerminated = new AtomicBoolean(false); - final AtomicReference transportSubscription = new AtomicReference<>(); - - // subscribe to transport to get Frames - connection.getInput().subscribe(new Observer() { - - @Override - public void onSubscribe(Disposable d) { - if (transportSubscription.compareAndSet(null, d)) { - // mark that we have completed setup - responderCompletable.success(); - } else { - // means we already were cancelled - d.dispose(); - } - } - - // null until after first Setup frame - volatile RequestHandler requestHandler = !isServer ? clientRequestHandler : null; - - @Override - public void onNext(Frame requestFrame) { - final int streamId = requestFrame.getStreamId(); - if (requestHandler == null) { // this will only happen when isServer==true - if (childTerminated.get()) { - // already terminated, but still receiving latent messages... - // ignore them while shutdown occurs - return; - } - if (requestFrame.getType().equals(FrameType.SETUP)) { - final ConnectionSetupPayload connectionSetupPayload = - ConnectionSetupPayload.create(requestFrame); - try { - int version = Frame.Setup.version(requestFrame); - if (version != SetupFrameFlyweight.CURRENT_VERSION) { - throw new SetupException("unsupported protocol version: " - + version); - } - - // accept setup for ReactiveSocket/Requester usage - setupCallback.accept(connectionSetupPayload); - // handle setup - requestHandler = connectionHandler.apply(connectionSetupPayload, reactiveSocket); - } catch (SetupException setupException) { - setupErrorAndTearDown(connection, setupException); - } catch (Throwable e) { - InvalidSetupException exc = new InvalidSetupException(e.getMessage()); - setupErrorAndTearDown(connection, exc); - } - - // the L bit set must wait until the application logic explicitly sends - // a LEASE. ConnectionSetupPlayload knows of bits being set. - if (connectionSetupPayload.willClientHonorLease()) { - leaseGovernor.register(Responder.this); - } else { - leaseGovernor = LeaseGovernor.UNLIMITED_LEASE_GOVERNOR; - } - - // TODO: handle keepalive logic here - } else { - setupErrorAndTearDown(connection, - new InvalidSetupException("Setup frame missing")); - } - } else { - Publisher responsePublisher = null; - if (leaseGovernor.accept(Responder.this, requestFrame)) { - try { - if (requestFrame.getType() == FrameType.REQUEST_RESPONSE) { - responsePublisher = handleRequestResponse( - requestFrame, requestHandler, cancellationSubscriptions); - } else if (requestFrame.getType() == FrameType.REQUEST_STREAM) { - responsePublisher = handleRequestStream( - requestFrame, requestHandler, cancellationSubscriptions, inFlight); - } else if (requestFrame.getType() == FrameType.FIRE_AND_FORGET) { - responsePublisher = handleFireAndForget( - requestFrame, requestHandler); - } else if (requestFrame.getType() == FrameType.REQUEST_SUBSCRIPTION) { - responsePublisher = handleRequestSubscription( - requestFrame, requestHandler, cancellationSubscriptions, inFlight); - } else if (requestFrame.getType() == FrameType.REQUEST_CHANNEL) { - responsePublisher = handleRequestChannel( - requestFrame, requestHandler, channels, - cancellationSubscriptions, inFlight); - } else if (requestFrame.getType() == FrameType.METADATA_PUSH) { - responsePublisher = handleMetadataPush( - requestFrame, requestHandler); - } else if (requestFrame.getType() == FrameType.CANCEL) { - Subscription s; - synchronized (Responder.this) { - s = cancellationSubscriptions.get(requestFrame.getStreamId()); - } - if (s != null) { - s.cancel(); - } - return; - } else if (requestFrame.getType() == FrameType.REQUEST_N) { - SubscriptionArbiter inFlightSubscription; - synchronized (Responder.this) { - inFlightSubscription = inFlight.get(requestFrame.getStreamId()); - } - if (inFlightSubscription != null) { - long requestN = Frame.RequestN.requestN(requestFrame); - inFlightSubscription.addApplicationRequest(requestN); - return; - } - // TODO should we do anything if we don't find the stream? - // emitting an error is risky as the responder could have - // terminated and cleaned up already - } else if (requestFrame.getType() == FrameType.KEEPALIVE) { - // this client is alive. - timeOfLastKeepalive = System.nanoTime(); - // echo back if flag set - if (Frame.Keepalive.hasRespondFlag(requestFrame)) { - Frame keepAliveFrame = Frame.Keepalive.from( - requestFrame.getData(), false); - responsePublisher = PublisherUtils.just(keepAliveFrame); - } else { - return; - } - } else if (requestFrame.getType() == FrameType.LEASE) { - // LEASE only concerns the Requester - } else { - IllegalStateException exc = new IllegalStateException( - "Unexpected prefix: " + requestFrame.getType()); - responsePublisher = PublisherUtils.errorFrame(streamId, exc); - } - } catch (Throwable e) { - // synchronous try/catch since we execute user functions - // in the handlers and they could throw - errorStream.accept( - new RuntimeException("Error in request handling.", e)); - // error message to user - responsePublisher = PublisherUtils.errorFrame( - streamId, new RuntimeException( - "Unhandled error processing request")); - } - } else { - RejectedException exception = new RejectedException("No associated lease"); - responsePublisher = PublisherUtils.errorFrame(streamId, exception); - } - - if (responsePublisher != null) { - connection.addOutput(responsePublisher, new Completable() { - @Override - public void success() { - // TODO Auto-generated method stub - } - - @Override - public void error(Throwable e) { - // TODO validate with unit tests - if (childTerminated.compareAndSet(false, true)) { - // TODO should we have typed RuntimeExceptions? - errorStream.accept(new RuntimeException("Error writing", e)); - cancel(); - } - } - }); - } - } - } - - private void setupErrorAndTearDown( - DuplexConnection connection, - SetupException setupException - ) { - // pass the ErrorFrame output, subscribe to write it, await - // onComplete and then tear down - final Frame frame = Frame.Error.from(0, setupException); - connection.addOutput(PublisherUtils.just(frame), - new Completable() { - @Override - public void success() { - tearDownWithError(setupException); - } - @Override - public void error(Throwable e) { - RuntimeException exc = new RuntimeException( - "Failure outputting SetupException", e); - tearDownWithError(exc); - } - }); - } - - private void tearDownWithError(Throwable se) { - // TODO unit test that this actually shuts things down - onError(new RuntimeException("Connection Setup Failure", se)); - } - - @Override - public void onError(Throwable t) { - // TODO validate with unit tests - if (childTerminated.compareAndSet(false, true)) { - errorStream.accept(t); - cancel(); - } - } - - @Override - public void onComplete() { - //TODO validate what is happening here - // this would mean the connection gracefully shut down, which is unexpected - if (childTerminated.compareAndSet(false, true)) { - cancel(); - } - } - - private void cancel() { - // child has cancelled (shutdown the connection or server) - // TODO validate with unit tests - if (!transportSubscription.compareAndSet(null, EmptyDisposable.EMPTY)) { - // cancel the one that was there if we failed to set the sentinel - transportSubscription.get().dispose(); - } - } - - }); - } - - public void shutdown() { - // TODO do something here - System.err.println("**** Responder.shutdown => this should actually do something"); - } - - private Publisher handleRequestResponse( - Frame requestFrame, - final RequestHandler requestHandler, - final Int2ObjectHashMap cancellationSubscriptions) { - - return (Subscriber child) -> { - Subscription s = new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - final AtomicReference parent = new AtomicReference<>(); - - @Override - public void request(long n) { - if (n > 0 && started.compareAndSet(false, true)) { - final int streamId = requestFrame.getStreamId(); - - Publisher responsePublisher = - requestHandler.handleRequestResponse(requestFrame); - responsePublisher.subscribe(new Subscriber() { - - // event emission is serialized so this doesn't need to be atomic - int count = 0; - - @Override - public void onSubscribe(Subscription s) { - if (parent.compareAndSet(null, s)) { - // only expect 1 value so we don't need REQUEST_N - s.request(Long.MAX_VALUE); - } else { - s.cancel(); - cleanup(); - } - } - - @Override - public void onNext(Payload v) { - if (++count > 1) { - IllegalStateException exc = new IllegalStateException( - "RequestResponse expects a single onNext"); - onError(exc); - } else { - Frame nextCompleteFrame = Frame.Response.from( - streamId, FrameType.RESPONSE, v.getMetadata(), v.getData(), FrameHeaderFlyweight.FLAGS_RESPONSE_C); - child.onNext(nextCompleteFrame); - } - } - - @Override - public void onError(Throwable t) { - child.onNext(Frame.Error.from(streamId, t)); - cleanup(); - } - - @Override - public void onComplete() { - if (count != 1) { - IllegalStateException exc = new IllegalStateException( - "RequestResponse expects a single onNext"); - onError(exc); - } else { - child.onComplete(); - cleanup(); - } - } - - }); - } - } - - @Override - public void cancel() { - if (!parent.compareAndSet(null, EmptySubscription.INSTANCE)) { - parent.get().cancel(); - cleanup(); - } - } - - private void cleanup() { - synchronized(Responder.this) { - cancellationSubscriptions.remove(requestFrame.getStreamId()); - } - } - - }; - synchronized(Responder.this) { - cancellationSubscriptions.put(requestFrame.getStreamId(), s); - } - child.onSubscribe(s); - }; - } - - private static BiFunction> - requestSubscriptionHandler = RequestHandler::handleSubscription; - private static BiFunction> - requestStreamHandler = RequestHandler::handleRequestStream; - - private Publisher handleRequestStream( - Frame requestFrame, - final RequestHandler requestHandler, - final Int2ObjectHashMap cancellationSubscriptions, - final Int2ObjectHashMap inFlight) { - return _handleRequestStream( - requestStreamHandler, - requestFrame, - requestHandler, - cancellationSubscriptions, - inFlight, - true - ); - } - - private Publisher handleRequestSubscription( - Frame requestFrame, - final RequestHandler requestHandler, - final Int2ObjectHashMap cancellationSubscriptions, - final Int2ObjectHashMap inFlight) { - return _handleRequestStream( - requestSubscriptionHandler, - requestFrame, - requestHandler, - cancellationSubscriptions, - inFlight, - false - ); - } - - /** - * Common logic for requestStream and requestSubscription - * - * @param handler - * @param requestFrame - * @param cancellationSubscriptions - * @param inFlight - * @param allowCompletion - * @return - */ - private Publisher _handleRequestStream( - BiFunction> handler, - Frame requestFrame, - final RequestHandler requestHandler, - final Int2ObjectHashMap cancellationSubscriptions, - final Int2ObjectHashMap inFlight, - final boolean allowCompletion) { - - return new Publisher() { - - @Override - public void subscribe(Subscriber child) { - Subscription s = new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - final AtomicReference parent = new AtomicReference<>(); - final SubscriptionArbiter arbiter = new SubscriptionArbiter(); - - @Override - public void request(long n) { - if(n <= 0) { - return; - } - if (started.compareAndSet(false, true)) { - arbiter.addTransportRequest(n); - final int streamId = requestFrame.getStreamId(); - - Publisher responses = - handler.apply(requestHandler, requestFrame); - responses.subscribe(new Subscriber() { - - @Override - public void onSubscribe(Subscription s) { - if (parent.compareAndSet(null, s)) { - inFlight.put(streamId, arbiter); - long n = Frame.Request.initialRequestN(requestFrame); - arbiter.addApplicationRequest(n); - arbiter.addApplicationProducer(s); - } else { - s.cancel(); - cleanup(); - } - } - - @Override - public void onNext(Payload v) { - try { - Frame nextFrame = Frame.Response.from( - streamId, FrameType.NEXT, v); - child.onNext(nextFrame); - } catch (Throwable e) { - onError(e); - } - } - - @Override - public void onError(Throwable t) { - child.onNext(Frame.Error.from(streamId, t)); - child.onComplete(); - cleanup(); - } - - @Override - public void onComplete() { - if (allowCompletion) { - Frame completeFrame = Frame.Response.from( - streamId, FrameType.COMPLETE); - child.onNext(completeFrame); - child.onComplete(); - cleanup(); - } else { - IllegalStateException exc = new IllegalStateException( - "Unexpected onComplete occurred on " + - "'requestSubscription'"); - onError(exc); - } - } - }); - } else { - arbiter.addTransportRequest(n); - } - } - - @Override - public void cancel() { - if (!parent.compareAndSet(null, EmptySubscription.INSTANCE)) { - parent.get().cancel(); - cleanup(); - } - } - - private void cleanup() { - synchronized(Responder.this) { - inFlight.remove(requestFrame.getStreamId()); - cancellationSubscriptions.remove(requestFrame.getStreamId()); - } - } - - }; - synchronized(Responder.this) { - cancellationSubscriptions.put(requestFrame.getStreamId(), s); - } - child.onSubscribe(s); - } - - }; - - } - - private Publisher handleFireAndForget( - Frame requestFrame, - final RequestHandler requestHandler - ) { - try { - requestHandler.handleFireAndForget(requestFrame).subscribe(completionSubscriber); - } catch (Throwable e) { - // we catch these errors here as we don't want anything propagating - // back to the user on fireAndForget - errorStream.accept(new RuntimeException("Error processing 'fireAndForget'", e)); - } - // we always treat this as if it immediately completes as we don't want - // errors passing back to the user - return PublisherUtils.empty(); - } - - private Publisher handleMetadataPush( - Frame requestFrame, - final RequestHandler requestHandler - ) { - try { - requestHandler.handleMetadataPush(requestFrame).subscribe(completionSubscriber); - } catch (Throwable e) { - // we catch these errors here as we don't want anything propagating - // back to the user on metadataPush - errorStream.accept(new RuntimeException("Error processing 'metadataPush'", e)); - } - // we always treat this as if it immediately completes as we don't want - // errors passing back to the user - return PublisherUtils.empty(); - } - - /** - * Reusable for each fireAndForget and metadataPush since no state is shared - * across invocations. It just passes through errors. - */ - private final Subscriber completionSubscriber = new Subscriber(){ - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Void t) {} - - @Override public void onError(Throwable t) { - errorStream.accept(t); - } - - @Override public void onComplete() {} - }; - - private Publisher handleRequestChannel(Frame requestFrame, - RequestHandler requestHandler, - Int2ObjectHashMap> channels, - Int2ObjectHashMap cancellationSubscriptions, - Int2ObjectHashMap inFlight) { - - UnicastSubject channelSubject; - synchronized(Responder.this) { - channelSubject = channels.get(requestFrame.getStreamId()); - } - if (channelSubject == null) { - return new Publisher() { - - @Override - public void subscribe(Subscriber child) { - Subscription s = new Subscription() { - - final AtomicBoolean started = new AtomicBoolean(false); - final AtomicReference parent = new AtomicReference<>(); - final SubscriptionArbiter arbiter = new SubscriptionArbiter(); - - @Override - public void request(long n) { - if(n <= 0) { - return; - } - if (started.compareAndSet(false, true)) { - arbiter.addTransportRequest(n); - final int streamId = requestFrame.getStreamId(); - - // first request on this channel - UnicastSubject channelRequests = - UnicastSubject.create((s, rn) -> { - // after we are first subscribed to then send - // the initial frame - s.onNext(requestFrame); - if (rn.intValue() > 0) { - // initial requestN back to the requester (subtract 1 - // for the initial frame which was already sent) - child.onNext(Frame.RequestN.from(streamId, rn.intValue() - 1)); - } - }, r -> { - // requested - child.onNext(Frame.RequestN.from(streamId, r.intValue())); - }); - synchronized(Responder.this) { - if(channels.get(streamId) != null) { - // TODO validate that this correctly defends - // against this issue, this means we received a - // followup request that raced and that the requester - // didn't correct wait for REQUEST_N before sending - // more frames - RuntimeException exc = new RuntimeException( - "Requester sent more than 1 requestChannel " + - "frame before permitted."); - child.onNext(Frame.Error.from(streamId, exc)); - child.onComplete(); - cleanup(); - return; - } - channels.put(streamId, channelRequests); - } - - Publisher responses = requestHandler.handleChannel(requestFrame, channelRequests); - responses.subscribe(new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - if (parent.compareAndSet(null, s)) { - inFlight.put(streamId, arbiter); - long n = Frame.Request.initialRequestN(requestFrame); - arbiter.addApplicationRequest(n); - arbiter.addApplicationProducer(s); - } else { - s.cancel(); - cleanup(); - } - } - - @Override - public void onNext(Payload v) { - Frame nextFrame = Frame.Response.from( - streamId, FrameType.NEXT, v); - child.onNext(nextFrame); - } - - @Override - public void onError(Throwable t) { - child.onNext(Frame.Error.from(streamId, t)); - child.onComplete(); - cleanup(); - } - - @Override - public void onComplete() { - Frame completeFrame = Frame.Response.from( - streamId, FrameType.COMPLETE); - child.onNext(completeFrame); - child.onComplete(); - cleanup(); - } - }); - } else { - arbiter.addTransportRequest(n); - } - } - - @Override - public void cancel() { - if (!parent.compareAndSet(null, EmptySubscription.INSTANCE)) { - parent.get().cancel(); - cleanup(); - } - } - - private void cleanup() { - synchronized(Responder.this) { - inFlight.remove(requestFrame.getStreamId()); - cancellationSubscriptions.remove(requestFrame.getStreamId()); - } - } - - }; - synchronized(Responder.this) { - cancellationSubscriptions.put(requestFrame.getStreamId(), s); - } - child.onSubscribe(s); - } - - }; - - } else { - // send data to channel - if (channelSubject.isSubscribedTo()) { - if(Frame.Request.isRequestChannelComplete(requestFrame)) { - channelSubject.onComplete(); - } else { - // TODO this is ignoring requestN flow control (need to validate - // that this is legit because REQUEST_N across the wire is - // controlling it on the Requester side) - channelSubject.onNext(requestFrame); - } - // TODO should at least have an error message of some kind if the - // Requester disregarded it - return PublisherUtils.empty(); - } else { - // TODO should we use a BufferUntilSubscriber solution instead to - // handle time-gap issues like this? - // TODO validate with unit tests. - return PublisherUtils.errorFrame( - requestFrame.getStreamId(), new RuntimeException("Channel unavailable")); - } - } - } - - private static class SubscriptionArbiter { - private Subscription applicationProducer; - private long appRequested = 0; - private long transportRequested = 0; - private long requestedToProducer = 0; - - public void addApplicationRequest(long n) { - synchronized(this) { - appRequested += n; - } - tryRequest(); - } - - public void addApplicationProducer(Subscription s) { - synchronized(this) { - applicationProducer = s; - } - tryRequest(); - } - - public void addTransportRequest(long n) { - synchronized(this) { - transportRequested += n; - } - tryRequest(); - } - - private void tryRequest() { - long toRequest; - synchronized(this) { - if(applicationProducer == null) { - return; - } - long minToRequest = Math.min(appRequested, transportRequested); - toRequest = minToRequest - requestedToProducer; - requestedToProducer += toRequest; - } - if(toRequest > 0) { - applicationProducer.request(toRequest); - } - } - - } - -} diff --git a/src/main/java/io/reactivesocket/internal/UnicastSubject.java b/src/main/java/io/reactivesocket/internal/UnicastSubject.java deleted file mode 100644 index fa23c366f..000000000 --- a/src/main/java/io/reactivesocket/internal/UnicastSubject.java +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import java.util.function.BiConsumer; -import java.util.function.Consumer; - -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -/** - * Intended to ONLY support a single Subscriber. It will throw an exception if more than 1 subscribe occurs. - *

- * This differs from PublishSubject which allows multicasting. This is done for efficiency reasons. - *

- * This is NOT thread-safe. - */ -public final class UnicastSubject implements Subscriber, Publisher { - - private Subscriber s; - private final BiConsumer, Long> onConnect; - private final Consumer onRequest; - private boolean subscribedTo = false; - - public static UnicastSubject create() { - return new UnicastSubject<>(null, r -> {}); - } - - /** - * @param onConnect Called when first requestN > 0 occurs. - * @param onRequest Called for each requestN after the first one (which invokes onConnect) - * @return - */ - public static UnicastSubject create(BiConsumer, Long> onConnect, Consumer onRequest) { - return new UnicastSubject<>(onConnect, onRequest); - } - - /** - * @param onConnect Called when first requestN > 0 occurs. - * @return - */ - public static UnicastSubject create(BiConsumer, Long> onConnect) { - return new UnicastSubject<>(onConnect, r -> {}); - } - - private UnicastSubject(BiConsumer, Long> onConnect, Consumer onRequest) { - this.onConnect = onConnect; - this.onRequest = onRequest; - } - - @Override - public void onSubscribe(Subscription s) { - throw new IllegalStateException("This UnicastSubject does not support being used as a Subscriber to a Publisher"); - } - - @Override - public void onNext(T t) { - s.onNext(t); - } - - @Override - public void onError(Throwable t) { - s.onError(t); - } - - @Override - public void onComplete() { - s.onComplete(); - } - - @Override - public void subscribe(Subscriber s) { - if (this.s != null) { - s.onError(new IllegalStateException("Only single Subscriber supported")); - } else { - this.s = s; - this.s.onSubscribe(new Subscription() { - - boolean started = false; - - @Override - public void request(long n) { - if (n > 0) { - if (!started) { - started = true; - subscribedTo = true; - // now actually connected - if (onConnect != null) { - onConnect.accept(UnicastSubject.this, n); - } - } else { - onRequest.accept(n); - } - } - } - - @Override - public void cancel() { - // transport has shut us down - } - - }); - } - } - - public boolean isSubscribedTo() { - return subscribedTo; - } - -} diff --git a/src/main/java/io/reactivesocket/internal/frame/ByteBufferUtil.java b/src/main/java/io/reactivesocket/internal/frame/ByteBufferUtil.java deleted file mode 100644 index 48774aca4..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/ByteBufferUtil.java +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import java.nio.ByteBuffer; - -public class ByteBufferUtil -{ - - private ByteBufferUtil() {} - - /** - * Slice a portion of the {@link ByteBuffer} while preserving the buffers position and limit. - * - * NOTE: Missing functionaity from {@link ByteBuffer} - * - * @param byteBuffer to slice off of - * @param position to start slice at - * @param limit to slice to - * @return slice of byteBuffer with passed ByteBuffer preserved position and limit. - */ - public static ByteBuffer preservingSlice(final ByteBuffer byteBuffer, final int position, final int limit) - { - final int savedPosition = byteBuffer.position(); - final int savedLimit = byteBuffer.limit(); - - byteBuffer.limit(limit).position(position); - - final ByteBuffer result = byteBuffer.slice(); - - byteBuffer.limit(savedLimit).position(savedPosition); - return result; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/ErrorFrameFlyweight.java b/src/main/java/io/reactivesocket/internal/frame/ErrorFrameFlyweight.java deleted file mode 100644 index fdb9f57e0..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/ErrorFrameFlyweight.java +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.FrameType; -import io.reactivesocket.exceptions.ApplicationException; -import io.reactivesocket.exceptions.CancelException; -import io.reactivesocket.exceptions.ConnectionException; -import io.reactivesocket.exceptions.InvalidRequestException; -import io.reactivesocket.exceptions.InvalidSetupException; -import io.reactivesocket.exceptions.RejectedException; -import io.reactivesocket.exceptions.RejectedSetupException; -import io.reactivesocket.exceptions.UnsupportedSetupException; -import org.agrona.BitUtil; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -public class ErrorFrameFlyweight { - - private ErrorFrameFlyweight() {} - - // defined error codes - public static final int INVALID_SETUP = 0x0001; - public static final int UNSUPPORTED_SETUP = 0x0002; - public static final int REJECTED_SETUP = 0x0003; - public static final int CONNECTION_ERROR = 0x0011; - public static final int APPLICATION_ERROR = 0x0021; - public static final int REJECTED = 0x0022; - public static final int CANCEL = 0x0023; - public static final int INVALID = 0x0024; - - // relative to start of passed offset - private static final int ERROR_CODE_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - private static final int PAYLOAD_OFFSET = ERROR_CODE_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - - public static int computeFrameLength( - final int metadataLength, - final int dataLength - ) { - int length = FrameHeaderFlyweight.computeFrameHeaderLength( - FrameType.ERROR, metadataLength, dataLength); - return length + BitUtil.SIZE_OF_INT; - } - - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final int streamId, - final int errorCode, - final ByteBuffer metadata, - final ByteBuffer data - ) { - final int frameLength = computeFrameLength(metadata.remaining(), data.remaining()); - - int length = FrameHeaderFlyweight.encodeFrameHeader( - mutableDirectBuffer, offset, frameLength, 0, FrameType.ERROR, streamId); - - mutableDirectBuffer.putInt( - offset + ERROR_CODE_FIELD_OFFSET, errorCode, ByteOrder.BIG_ENDIAN); - length += BitUtil.SIZE_OF_INT; - - length += FrameHeaderFlyweight.encodeMetadata( - mutableDirectBuffer, offset, offset + length, metadata); - length += FrameHeaderFlyweight.encodeData(mutableDirectBuffer, offset + length, data); - - return length; - } - - public static int errorCodeFromException(Throwable ex) { - if (ex instanceof InvalidSetupException) { - return INVALID_SETUP; - } else if (ex instanceof UnsupportedSetupException) { - return UNSUPPORTED_SETUP; - } else if (ex instanceof RejectedSetupException) { - return REJECTED_SETUP; - } else if (ex instanceof ConnectionException) { - return CONNECTION_ERROR; - } else if (ex instanceof InvalidRequestException) { - return INVALID; - } else if (ex instanceof ApplicationException) { - return APPLICATION_ERROR; - } else if (ex instanceof RejectedException) { - return REJECTED; - } else if (ex instanceof CancelException) { - return CANCEL; - } - return INVALID; - } - - public static int errorCode(final DirectBuffer directBuffer, final int offset) { - return directBuffer.getInt(offset + ERROR_CODE_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static int payloadOffset(final DirectBuffer directBuffer, final int offset) { - return offset + FrameHeaderFlyweight.FRAME_HEADER_LENGTH + BitUtil.SIZE_OF_INT; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/FrameHeaderFlyweight.java b/src/main/java/io/reactivesocket/internal/frame/FrameHeaderFlyweight.java deleted file mode 100644 index 06791073f..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/FrameHeaderFlyweight.java +++ /dev/null @@ -1,329 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.FrameType; -import org.agrona.BitUtil; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -import static io.reactivesocket.internal.frame.ByteBufferUtil.preservingSlice; - -/** - * Per connection frame flyweight. - * - * Not the latest frame layout, but close. - * Does not include - * - fragmentation / reassembly - * - encode should remove Type param and have it as part of method name (1 encode per type?) - * - * Not thread-safe. Assumed to be used single-threaded - */ -public class FrameHeaderFlyweight -{ - - private FrameHeaderFlyweight() {} - - public static final ByteBuffer NULL_BYTEBUFFER = ByteBuffer.allocate(0); - - public static final int FRAME_HEADER_LENGTH; - - private static final boolean INCLUDE_FRAME_LENGTH = true; - - private static final int FRAME_LENGTH_FIELD_OFFSET; - private static final int TYPE_FIELD_OFFSET; - private static final int FLAGS_FIELD_OFFSET; - private static final int STREAM_ID_FIELD_OFFSET; - private static final int PAYLOAD_OFFSET; - - public static final int FLAGS_I = 0b1000_0000_0000_0000; - public static final int FLAGS_M = 0b0100_0000_0000_0000; - - public static final int FLAGS_KEEPALIVE_R = 0b0010_0000_0000_0000; - - public static final int FLAGS_RESPONSE_F = 0b0010_0000_0000_0000; - public static final int FLAGS_RESPONSE_C = 0b0001_0000_0000_0000; - - public static final int FLAGS_REQUEST_CHANNEL_F = 0b0010_0000_0000_0000; - - static - { - if (INCLUDE_FRAME_LENGTH) - { - FRAME_LENGTH_FIELD_OFFSET = 0; - } - else - { - FRAME_LENGTH_FIELD_OFFSET = -BitUtil.SIZE_OF_INT; - } - - TYPE_FIELD_OFFSET = FRAME_LENGTH_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - FLAGS_FIELD_OFFSET = TYPE_FIELD_OFFSET + BitUtil.SIZE_OF_SHORT; - STREAM_ID_FIELD_OFFSET = FLAGS_FIELD_OFFSET + BitUtil.SIZE_OF_SHORT; - PAYLOAD_OFFSET = STREAM_ID_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - - FRAME_HEADER_LENGTH = PAYLOAD_OFFSET; - } - - public static int computeFrameHeaderLength(final FrameType frameType, int metadataLength, final int dataLength) - { - return PAYLOAD_OFFSET + computeMetadataLength(metadataLength) + dataLength; - } - - public static int encodeFrameHeader( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final int frameLength, - final int flags, - final FrameType frameType, - final int streamId) - { - if (INCLUDE_FRAME_LENGTH) - { - mutableDirectBuffer.putInt(offset + FRAME_LENGTH_FIELD_OFFSET, frameLength, ByteOrder.BIG_ENDIAN); - } - - mutableDirectBuffer.putShort(offset + TYPE_FIELD_OFFSET, (short) frameType.getEncodedType(), ByteOrder.BIG_ENDIAN); - mutableDirectBuffer.putShort(offset + FLAGS_FIELD_OFFSET, (short) flags, ByteOrder.BIG_ENDIAN); - mutableDirectBuffer.putInt(offset + STREAM_ID_FIELD_OFFSET, streamId, ByteOrder.BIG_ENDIAN); - - return FRAME_HEADER_LENGTH; - } - - public static int encodeMetadata( - final MutableDirectBuffer mutableDirectBuffer, - final int frameHeaderStartOffset, - final int metadataOffset, - final ByteBuffer metadata) - { - int length = 0; - final int metadataLength = metadata.remaining(); - - if (0 < metadataLength) - { - int flags = mutableDirectBuffer.getShort(frameHeaderStartOffset + FLAGS_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - flags |= FLAGS_M; - mutableDirectBuffer.putShort(frameHeaderStartOffset + FLAGS_FIELD_OFFSET, (short)flags, ByteOrder.BIG_ENDIAN); - mutableDirectBuffer.putInt(metadataOffset, metadata.capacity() + BitUtil.SIZE_OF_INT, ByteOrder.BIG_ENDIAN); - length += BitUtil.SIZE_OF_INT; - mutableDirectBuffer.putBytes(metadataOffset + length, metadata, metadataLength); - length += metadataLength; - } - - return length; - } - - public static int encodeData( - final MutableDirectBuffer mutableDirectBuffer, - final int dataOffset, - final ByteBuffer data) - { - int length = 0; - final int dataLength = data.remaining(); - - if (0 < data.capacity()) - { - mutableDirectBuffer.putBytes(dataOffset, data, dataLength); - length += dataLength; - } - - return length; - } - - // only used for types simple enough that they don't have their own FrameFlyweights - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final int streamId, - int flags, - final FrameType frameType, - final ByteBuffer metadata, - final ByteBuffer data) - { - final int frameLength = computeFrameHeaderLength(frameType, metadata.remaining(), data.remaining()); - - final FrameType outFrameType; - - switch (frameType) - { - case COMPLETE: - outFrameType = FrameType.RESPONSE; - flags |= FLAGS_RESPONSE_C; - break; - case NEXT: - outFrameType = FrameType.RESPONSE; - break; - default: - outFrameType = frameType; - break; - } - - int length = FrameHeaderFlyweight.encodeFrameHeader(mutableDirectBuffer, offset, frameLength, flags, outFrameType, streamId); - - length += FrameHeaderFlyweight.encodeMetadata(mutableDirectBuffer, offset, offset + length, metadata); - length += FrameHeaderFlyweight.encodeData(mutableDirectBuffer, offset + length, data); - - return length; - } - - public static int flags(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getShort(offset + FLAGS_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static FrameType frameType(final DirectBuffer directBuffer, final int offset) - { - FrameType result = FrameType.from(directBuffer.getShort(offset + TYPE_FIELD_OFFSET, ByteOrder.BIG_ENDIAN)); - - if (FrameType.RESPONSE == result) - { - final int flags = flags(directBuffer, offset); - final int dataLength = dataLength(directBuffer, offset, 0); - - if (FLAGS_RESPONSE_C == (flags & FLAGS_RESPONSE_C) && 0 < dataLength) - { - result = FrameType.NEXT_COMPLETE; - } - else if (FLAGS_RESPONSE_C == (flags & FLAGS_RESPONSE_C)) - { - result = FrameType.COMPLETE; - } - else - { - result = FrameType.NEXT; - } - } - - return result; - } - - public static int streamId(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + STREAM_ID_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static ByteBuffer sliceFrameData(final DirectBuffer directBuffer, final int offset, final int length) - { - final int dataLength = dataLength(directBuffer, offset, length); - final int dataOffset = dataOffset(directBuffer, offset); - ByteBuffer result = NULL_BYTEBUFFER; - - if (0 < dataLength) - { - result = preservingSlice(directBuffer.byteBuffer(), dataOffset, dataOffset + dataLength); - } - - return result; - } - - public static ByteBuffer sliceFrameMetadata(final DirectBuffer directBuffer, final int offset, final int length) - { - final int metadataLength = Math.max(0, metadataFieldLength(directBuffer, offset) - BitUtil.SIZE_OF_INT); - final int metadataOffset = metadataOffset(directBuffer, offset) + BitUtil.SIZE_OF_INT; - ByteBuffer result = NULL_BYTEBUFFER; - - if (0 < metadataLength) - { - result = preservingSlice(directBuffer.byteBuffer(), metadataOffset, metadataOffset + metadataLength); - } - - return result; - } - - private static int frameLength(final DirectBuffer directBuffer, final int offset, final int externalFrameLength) - { - int frameLength = externalFrameLength; - - if (INCLUDE_FRAME_LENGTH) - { - frameLength = directBuffer.getInt(offset + FRAME_LENGTH_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - return frameLength; - } - - private static int computeMetadataLength(final int metadataPayloadLength) - { - return metadataPayloadLength + ((0 == metadataPayloadLength) ? 0 : BitUtil.SIZE_OF_INT); - } - - private static int metadataFieldLength(final DirectBuffer directBuffer, final int offset) - { - int metadataLength = 0; - - if (FLAGS_M == (FLAGS_M & directBuffer.getShort(offset + FLAGS_FIELD_OFFSET, ByteOrder.BIG_ENDIAN))) - { - metadataLength = directBuffer.getInt(metadataOffset(directBuffer, offset), ByteOrder.BIG_ENDIAN) & 0xFFFFFF; - } - - return metadataLength; - } - - private static int dataLength(final DirectBuffer directBuffer, final int offset, final int externalLength) - { - final int frameLength = frameLength(directBuffer, offset, externalLength); - final int metadataLength = metadataFieldLength(directBuffer, offset); - - return offset + frameLength - metadataLength - payloadOffset(directBuffer, offset); - } - - private static int payloadOffset(final DirectBuffer directBuffer, final int offset) - { - final FrameType frameType = FrameType.from(directBuffer.getShort(offset + TYPE_FIELD_OFFSET, ByteOrder.BIG_ENDIAN)); - int result = offset + PAYLOAD_OFFSET; - - switch (frameType) - { - case SETUP: - result = SetupFrameFlyweight.payloadOffset(directBuffer, offset); - break; - case ERROR: - result = ErrorFrameFlyweight.payloadOffset(directBuffer, offset); - break; - case LEASE: - result = LeaseFrameFlyweight.payloadOffset(directBuffer, offset); - break; - case KEEPALIVE: - result = KeepaliveFrameFlyweight.payloadOffset(directBuffer, offset); - break; - case REQUEST_RESPONSE: - case FIRE_AND_FORGET: - case REQUEST_STREAM: - case REQUEST_SUBSCRIPTION: - case REQUEST_CHANNEL: - result = RequestFrameFlyweight.payloadOffset(frameType, directBuffer, offset); - break; - case REQUEST_N: - result = RequestNFrameFlyweight.payloadOffset(directBuffer, offset); - break; - } - - return result; - } - - private static int metadataOffset(final DirectBuffer directBuffer, final int offset) - { - return payloadOffset(directBuffer, offset); - } - - private static int dataOffset(final DirectBuffer directBuffer, final int offset) - { - return payloadOffset(directBuffer, offset) + metadataFieldLength(directBuffer, offset); - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/FramePool.java b/src/main/java/io/reactivesocket/internal/frame/FramePool.java deleted file mode 100644 index 9b37a643d..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/FramePool.java +++ /dev/null @@ -1,38 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.Frame; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; - -public interface FramePool -{ - Frame acquireFrame(final int size); - - Frame acquireFrame(final ByteBuffer byteBuffer); - - Frame acquireFrame(final MutableDirectBuffer mutableDirectBuffer); - - MutableDirectBuffer acquireMutableDirectBuffer(final int size); - - MutableDirectBuffer acquireMutableDirectBuffer(final ByteBuffer byteBuffer); - - void release(final Frame frame); - - void release(final MutableDirectBuffer mutableDirectBuffer); -} diff --git a/src/main/java/io/reactivesocket/internal/frame/KeepaliveFrameFlyweight.java b/src/main/java/io/reactivesocket/internal/frame/KeepaliveFrameFlyweight.java deleted file mode 100644 index a25350539..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/KeepaliveFrameFlyweight.java +++ /dev/null @@ -1,53 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.FrameType; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; - -public class KeepaliveFrameFlyweight -{ - private KeepaliveFrameFlyweight() {} - - private static final int PAYLOAD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - - public static int computeFrameLength(final int dataLength) - { - return FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.SETUP, 0, dataLength); - } - - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final ByteBuffer data) - { - final int frameLength = computeFrameLength(data.remaining()); - - int length = FrameHeaderFlyweight.encodeFrameHeader(mutableDirectBuffer, offset, frameLength, 0, FrameType.KEEPALIVE, 0); - - length += FrameHeaderFlyweight.encodeData(mutableDirectBuffer, offset + length, data); - - return length; - } - - public static int payloadOffset(final DirectBuffer directBuffer, final int offset) - { - return offset + PAYLOAD_OFFSET; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/LeaseFrameFlyweight.java b/src/main/java/io/reactivesocket/internal/frame/LeaseFrameFlyweight.java deleted file mode 100644 index ceed7ee62..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/LeaseFrameFlyweight.java +++ /dev/null @@ -1,76 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.FrameType; -import org.agrona.BitUtil; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -public class LeaseFrameFlyweight -{ - private LeaseFrameFlyweight() {} - - // relative to start of passed offset - private static final int TTL_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - private static final int NUM_REQUESTS_FIELD_OFFSET = TTL_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - private static final int PAYLOAD_OFFSET = NUM_REQUESTS_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - - public static int computeFrameLength(final int metadataLength) - { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.SETUP, metadataLength, 0); - - return length + BitUtil.SIZE_OF_INT * 2; - } - - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final int ttl, - final int numRequests, - final ByteBuffer metadata) - { - final int frameLength = computeFrameLength(metadata.remaining()); - - int length = FrameHeaderFlyweight.encodeFrameHeader(mutableDirectBuffer, offset, frameLength, 0, FrameType.LEASE, 0); - - mutableDirectBuffer.putInt(offset + TTL_FIELD_OFFSET, ttl, ByteOrder.BIG_ENDIAN); - mutableDirectBuffer.putInt(offset + NUM_REQUESTS_FIELD_OFFSET, numRequests, ByteOrder.BIG_ENDIAN); - - length += BitUtil.SIZE_OF_INT * 2; - length += FrameHeaderFlyweight.encodeMetadata(mutableDirectBuffer, offset, offset + length, metadata); - - return length; - } - - public static int ttl(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + TTL_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static int numRequests(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + NUM_REQUESTS_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static int payloadOffset(final DirectBuffer directBuffer, final int offset) - { - return offset + PAYLOAD_OFFSET; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/PayloadBuilder.java b/src/main/java/io/reactivesocket/internal/frame/PayloadBuilder.java deleted file mode 100644 index 4db09e867..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/PayloadBuilder.java +++ /dev/null @@ -1,138 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.Frame; -import io.reactivesocket.Payload; -import org.agrona.BitUtil; -import org.agrona.MutableDirectBuffer; -import org.agrona.concurrent.UnsafeBuffer; - -import java.nio.ByteBuffer; -import java.util.Arrays; - -/** - * Builder for appending buffers that grows dataCapacity as necessary. Similar to Aeron's PayloadBuilder. - */ -public class PayloadBuilder -{ - public static final int INITIAL_CAPACITY = Math.max(Frame.DATA_MTU, Frame.METADATA_MTU); - - private final MutableDirectBuffer dataMutableDirectBuffer; - private final MutableDirectBuffer metadataMutableDirectBuffer; - - private byte[] dataBuffer; - private byte[] metadataBuffer; - private int dataLimit = 0; - private int metadataLimit = 0; - private int dataCapacity; - private int metadataCapacity; - - public PayloadBuilder() - { - dataCapacity = BitUtil.findNextPositivePowerOfTwo(INITIAL_CAPACITY); - metadataCapacity = BitUtil.findNextPositivePowerOfTwo(INITIAL_CAPACITY); - dataBuffer = new byte[dataCapacity]; - metadataBuffer = new byte[metadataCapacity]; - dataMutableDirectBuffer = new UnsafeBuffer(dataBuffer); - metadataMutableDirectBuffer = new UnsafeBuffer(metadataBuffer); - } - - public Payload payload() - { - return new Payload() - { - public ByteBuffer getData() - { - return ByteBuffer.wrap(dataBuffer, 0, dataLimit); - } - - public ByteBuffer getMetadata() - { - return ByteBuffer.wrap(metadataBuffer, 0, metadataLimit); - } - }; - } - - public void append(final Payload payload) - { - final ByteBuffer payloadData = payload.getData(); - final ByteBuffer payloadMetadata = payload.getMetadata(); - final int dataLength = payloadData.remaining(); - final int metadataLength = payloadMetadata.remaining(); - - ensureDataCapacity(dataLength); - ensureMetadataCapacity(metadataLength); - - dataMutableDirectBuffer.putBytes(dataLimit, payloadData, payloadData.capacity()); - dataLimit += dataLength; - metadataMutableDirectBuffer.putBytes(metadataLimit, payloadMetadata, payloadMetadata.capacity()); - metadataLimit += metadataLength; - } - - private void ensureDataCapacity(final int additionalCapacity) - { - final int requiredCapacity = dataLimit + additionalCapacity; - - if (requiredCapacity < 0) - { - final String s = String.format("Insufficient data capacity: dataLimit=%d additional=%d", dataLimit, additionalCapacity); - throw new IllegalStateException(s); - } - - if (requiredCapacity > dataCapacity) - { - final int newCapacity = findSuitableCapacity(dataCapacity, requiredCapacity); - final byte[] newBuffer = Arrays.copyOf(dataBuffer, newCapacity); - - dataCapacity = newCapacity; - dataBuffer = newBuffer; - dataMutableDirectBuffer.wrap(newBuffer); - } - } - - private void ensureMetadataCapacity(final int additionalCapacity) - { - final int requiredCapacity = metadataLimit + additionalCapacity; - - if (requiredCapacity < 0) - { - final String s = String.format("Insufficient metadata capacity: metadataLimit=%d additional=%d", metadataLimit, additionalCapacity); - throw new IllegalStateException(s); - } - - if (requiredCapacity > metadataCapacity) - { - final int newCapacity = findSuitableCapacity(metadataCapacity, requiredCapacity); - final byte[] newBuffer = Arrays.copyOf(metadataBuffer, newCapacity); - - metadataCapacity = newCapacity; - metadataBuffer = newBuffer; - metadataMutableDirectBuffer.wrap(newBuffer); - } - } - - private static int findSuitableCapacity(int capacity, final int requiredCapacity) - { - do - { - capacity <<= 1; - } - while (capacity < requiredCapacity); - - return capacity; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/PayloadFragmenter.java b/src/main/java/io/reactivesocket/internal/frame/PayloadFragmenter.java deleted file mode 100644 index 98f3aba29..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/PayloadFragmenter.java +++ /dev/null @@ -1,149 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.Payload; - -import java.nio.ByteBuffer; -import java.util.Iterator; - -/** - * Stateful iterator that can be re-used. - * - * Not thread-safe - */ -public class PayloadFragmenter implements Iterable, Iterator -{ - private enum Type - { - RESPONSE, RESPONSE_COMPLETE, REQUEST_CHANNEL - } - - private final int metadataMtu; - private final int dataMtu; - private ByteBuffer metadata; - private ByteBuffer data; - private Type type; - private int metadataOffset; - private int dataOffset; - private int streamId; - private int initialRequestN; - - public PayloadFragmenter(final int metadataMtu, final int dataMtu) - { - this.metadataMtu = metadataMtu; - this.dataMtu = dataMtu; - } - - public void resetForResponse(final int streamId, final Payload payload) - { - reset(streamId, payload); - type = Type.RESPONSE; - } - - public void resetForResponseComplete(final int streamId, final Payload payload) - { - reset(streamId, payload); - type = Type.RESPONSE_COMPLETE; - } - - public void resetForRequestChannel(final int streamId, final Payload payload, final int initialRequestN) - { - reset(streamId, payload); - type = Type.REQUEST_CHANNEL; - this.initialRequestN = initialRequestN; - } - - public static boolean requiresFragmenting(final int metadataMtu, final int dataMtu, final Payload payload) - { - final ByteBuffer metadata = payload.getMetadata(); - final ByteBuffer data = payload.getData(); - - return metadata.remaining() > metadataMtu || data.remaining() > dataMtu; - } - - public Iterator iterator() - { - return this; - } - - public boolean hasNext() - { - return dataOffset < data.capacity() || metadataOffset < metadata.remaining(); - } - - public Frame next() - { - final int metadataLength = Math.min(metadataMtu, metadata.remaining() - metadataOffset); - final int dataLength = Math.min(dataMtu, data.remaining() - dataOffset); - - Frame result = null; - - final ByteBuffer metadataBuffer = metadataLength > 0 ? - ByteBufferUtil.preservingSlice(metadata, metadataOffset, metadataOffset + metadataLength) : Frame.NULL_BYTEBUFFER; - - final ByteBuffer dataBuffer = dataLength > 0 ? - ByteBufferUtil.preservingSlice(data, dataOffset, dataOffset + dataLength) : Frame.NULL_BYTEBUFFER; - - metadataOffset += metadataLength; - dataOffset += dataLength; - - final boolean isMoreFollowing = metadataOffset < metadata.remaining() || dataOffset < data.remaining(); - int flags = 0; - - if (Type.RESPONSE == type) - { - if (isMoreFollowing) - { - flags |= FrameHeaderFlyweight.FLAGS_RESPONSE_F; - } - - result = Frame.Response.from(streamId, FrameType.NEXT, metadataBuffer, dataBuffer, flags); - } - if (Type.RESPONSE_COMPLETE == type) - { - if (isMoreFollowing) - { - flags |= FrameHeaderFlyweight.FLAGS_RESPONSE_F; - } - - result = Frame.Response.from(streamId, FrameType.NEXT_COMPLETE, metadataBuffer, dataBuffer, flags); - } - else if (Type.REQUEST_CHANNEL == type) - { - if (isMoreFollowing) - { - flags |= FrameHeaderFlyweight.FLAGS_REQUEST_CHANNEL_F; - } - - result = Frame.Request.from(streamId, FrameType.REQUEST_CHANNEL, metadataBuffer, dataBuffer, initialRequestN, flags); - initialRequestN = 0; - } - - return result; - } - - private void reset(final int streamId, final Payload payload) - { - data = payload.getData(); - metadata = payload.getMetadata(); - metadataOffset = 0; - dataOffset = 0; - this.streamId = streamId; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/PayloadReassembler.java b/src/main/java/io/reactivesocket/internal/frame/PayloadReassembler.java deleted file mode 100644 index 6d5212028..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/PayloadReassembler.java +++ /dev/null @@ -1,91 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.Frame; -import io.reactivesocket.Payload; -import org.agrona.collections.Int2ObjectHashMap; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - - -public class PayloadReassembler implements Subscriber -{ - private final Subscriber child; - private final Int2ObjectHashMap payloadByStreamId = new Int2ObjectHashMap<>(); - - private PayloadReassembler(final Subscriber child) - { - this.child = child; - } - - public static PayloadReassembler with(final Subscriber child) - { - return new PayloadReassembler(child); - } - - public void resetStream(final int streamId) - { - payloadByStreamId.remove(streamId); - } - - public void onSubscribe(Subscription s) - { - // reset - } - - public void onNext(Frame frame) - { - // if frame has no F bit and no waiting payload, then simply pass on - final int streamId = frame.getStreamId(); - PayloadBuilder payloadBuilder = payloadByStreamId.get(streamId); - - if (FrameHeaderFlyweight.FLAGS_RESPONSE_F != (frame.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)) - { - Payload deliveryPayload = frame; - - // terminal frame - if (null != payloadBuilder) - { - payloadBuilder.append(frame); - deliveryPayload = payloadBuilder.payload(); - payloadByStreamId.remove(streamId); - } - - child.onNext(deliveryPayload); - } - else - { - if (null == payloadBuilder) - { - payloadBuilder = new PayloadBuilder(); - payloadByStreamId.put(streamId, payloadBuilder); - } - - payloadBuilder.append(frame); - } - } - - public void onError(Throwable t) - { - // reset and pass through - } - - public void onComplete() - { - // reset and pass through - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/RequestFrameFlyweight.java b/src/main/java/io/reactivesocket/internal/frame/RequestFrameFlyweight.java deleted file mode 100644 index 6db2ba997..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/RequestFrameFlyweight.java +++ /dev/null @@ -1,108 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.FrameType; -import org.agrona.BitUtil; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -public class RequestFrameFlyweight -{ - - private RequestFrameFlyweight() {} - - public static final int FLAGS_REQUEST_CHANNEL_C = 0b0001_0000_0000_0000; - public static final int FLAGS_REQUEST_CHANNEL_N = 0b0000_1000_0000_0000; - - // relative to start of passed offset - private static final int INITIAL_REQUEST_N_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - - public static int computeFrameLength(final FrameType type, final int metadataLength, final int dataLength) - { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(type, metadataLength, dataLength); - - if (type.hasInitialRequestN()) - { - length += BitUtil.SIZE_OF_INT; - } - - return length; - } - - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final int streamId, - int flags, - final FrameType type, - final int initialRequestN, - final ByteBuffer metadata, - final ByteBuffer data) - { - final int frameLength = computeFrameLength(type, metadata.remaining(), data.remaining()); - - flags |= FLAGS_REQUEST_CHANNEL_N; - int length = FrameHeaderFlyweight.encodeFrameHeader(mutableDirectBuffer, offset, frameLength, flags, type, streamId); - - mutableDirectBuffer.putInt(offset + INITIAL_REQUEST_N_FIELD_OFFSET, initialRequestN, ByteOrder.BIG_ENDIAN); - length += BitUtil.SIZE_OF_INT; - - length += FrameHeaderFlyweight.encodeMetadata(mutableDirectBuffer, offset, offset + length, metadata); - length += FrameHeaderFlyweight.encodeData(mutableDirectBuffer, offset + length, data); - - return length; - } - - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final int streamId, - final int flags, - final FrameType type, - final ByteBuffer metadata, - final ByteBuffer data) - { - final int frameLength = computeFrameLength(type, metadata.remaining(), data.remaining()); - - int length = FrameHeaderFlyweight.encodeFrameHeader(mutableDirectBuffer, offset, frameLength, flags, type, streamId); - - length += FrameHeaderFlyweight.encodeMetadata(mutableDirectBuffer, offset, offset + length, metadata); - length += FrameHeaderFlyweight.encodeData(mutableDirectBuffer, offset + length, data); - - return length; - } - - public static int initialRequestN(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + INITIAL_REQUEST_N_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static int payloadOffset(final FrameType type, final DirectBuffer directBuffer, final int offset) - { - int result = offset + FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - - if (type.hasInitialRequestN()) - { - result += BitUtil.SIZE_OF_INT; - } - - return result; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/RequestNFrameFlyweight.java b/src/main/java/io/reactivesocket/internal/frame/RequestNFrameFlyweight.java deleted file mode 100644 index 7ff679a90..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/RequestNFrameFlyweight.java +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.FrameType; -import org.agrona.BitUtil; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteOrder; - -public class RequestNFrameFlyweight -{ - private RequestNFrameFlyweight() {} - - // relative to start of passed offset - private static final int REQUEST_N_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - - public static int computeFrameLength() - { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.REQUEST_N, 0, 0); - - return length + BitUtil.SIZE_OF_INT; - } - - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - final int streamId, - final int requestN) - { - final int frameLength = computeFrameLength(); - - int length = FrameHeaderFlyweight.encodeFrameHeader(mutableDirectBuffer, offset, frameLength, 0, FrameType.REQUEST_N, streamId); - - mutableDirectBuffer.putInt(offset + REQUEST_N_FIELD_OFFSET, requestN, ByteOrder.BIG_ENDIAN); - - return length + BitUtil.SIZE_OF_INT; - } - - public static int requestN(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + REQUEST_N_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static int payloadOffset(final DirectBuffer directBuffer, final int offset) - { - return offset + FrameHeaderFlyweight.FRAME_HEADER_LENGTH + BitUtil.SIZE_OF_INT; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/SetupFrameFlyweight.java b/src/main/java/io/reactivesocket/internal/frame/SetupFrameFlyweight.java deleted file mode 100644 index db9df1d79..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/SetupFrameFlyweight.java +++ /dev/null @@ -1,156 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.FrameType; -import org.agrona.BitUtil; -import org.agrona.DirectBuffer; -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.charset.Charset; - -public class SetupFrameFlyweight -{ - private SetupFrameFlyweight() {} - - public static final int FLAGS_WILL_HONOR_LEASE = 0b0010_0000; - public static final int FLAGS_STRICT_INTERPRETATION = 0b0001_0000; - - public static final byte CURRENT_VERSION = 0; - - // relative to start of passed offset - private static final int VERSION_FIELD_OFFSET = FrameHeaderFlyweight.FRAME_HEADER_LENGTH; - private static final int KEEPALIVE_INTERVAL_FIELD_OFFSET = VERSION_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - private static final int MAX_LIFETIME_FIELD_OFFSET = KEEPALIVE_INTERVAL_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - private static final int METADATA_MIME_TYPE_LENGTH_OFFSET = MAX_LIFETIME_FIELD_OFFSET + BitUtil.SIZE_OF_INT; - - public static int computeFrameLength( - final String metadataMimeType, - final String dataMimeType, - final int metadataLength, - final int dataLength) - { - int length = FrameHeaderFlyweight.computeFrameHeaderLength(FrameType.SETUP, metadataLength, dataLength); - - length += BitUtil.SIZE_OF_INT * 3; - length += 1 + metadataMimeType.length(); - length += 1 + dataMimeType.length(); - - return length; - } - - public static int encode( - final MutableDirectBuffer mutableDirectBuffer, - final int offset, - int flags, - final int keepaliveInterval, - final int maxLifetime, - final String metadataMimeType, - final String dataMimeType, - final ByteBuffer metadata, - final ByteBuffer data) - { - final int frameLength = computeFrameLength(metadataMimeType, dataMimeType, metadata.remaining(), data.remaining()); - - int length = FrameHeaderFlyweight.encodeFrameHeader(mutableDirectBuffer, offset, frameLength, flags, FrameType.SETUP, 0); - - mutableDirectBuffer.putInt(offset + VERSION_FIELD_OFFSET, CURRENT_VERSION, ByteOrder.BIG_ENDIAN); - mutableDirectBuffer.putInt(offset + KEEPALIVE_INTERVAL_FIELD_OFFSET, keepaliveInterval, ByteOrder.BIG_ENDIAN); - mutableDirectBuffer.putInt(offset + MAX_LIFETIME_FIELD_OFFSET, maxLifetime, ByteOrder.BIG_ENDIAN); - - length += BitUtil.SIZE_OF_INT * 3; - - length += putMimeType(mutableDirectBuffer, offset + length, metadataMimeType); - length += putMimeType(mutableDirectBuffer, offset + length, dataMimeType); - - length += FrameHeaderFlyweight.encodeMetadata(mutableDirectBuffer, offset, offset + length, metadata); - length += FrameHeaderFlyweight.encodeData(mutableDirectBuffer, offset + length, data); - - return length; - } - - public static int version(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + VERSION_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static int keepaliveInterval(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + KEEPALIVE_INTERVAL_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static int maxLifetime(final DirectBuffer directBuffer, final int offset) - { - return directBuffer.getInt(offset + MAX_LIFETIME_FIELD_OFFSET, ByteOrder.BIG_ENDIAN); - } - - public static String metadataMimeType(final DirectBuffer directBuffer, final int offset) - { - final byte[] bytes = getMimeType(directBuffer, offset + METADATA_MIME_TYPE_LENGTH_OFFSET); - return new String(bytes, Charset.forName("UTF-8")); - } - - public static String dataMimeType(final DirectBuffer directBuffer, final int offset) - { - int fieldOffset = offset + METADATA_MIME_TYPE_LENGTH_OFFSET; - - fieldOffset += 1 + directBuffer.getByte(fieldOffset); - - final byte[] bytes = getMimeType(directBuffer, fieldOffset); - return new String(bytes, Charset.forName("UTF-8")); - } - - public static int computePayloadOffset( - final int offset, final int metadataMimeTypeLength, final int dataMimeTypeLength) - { - return offset + METADATA_MIME_TYPE_LENGTH_OFFSET + - 1 + metadataMimeTypeLength + - 1 + dataMimeTypeLength; - } - - public static int payloadOffset(final DirectBuffer directBuffer, final int offset) - { - int fieldOffset = offset + METADATA_MIME_TYPE_LENGTH_OFFSET; - - final int metadataMimeTypeLength = directBuffer.getByte(fieldOffset); - fieldOffset += 1 + metadataMimeTypeLength; - - final int dataMimeTypeLength = directBuffer.getByte(fieldOffset); - fieldOffset += 1 + dataMimeTypeLength; - - return fieldOffset; - } - - private static int putMimeType( - final MutableDirectBuffer mutableDirectBuffer, final int fieldOffset, final String mimeType) - { - mutableDirectBuffer.putByte(fieldOffset, (byte) mimeType.length()); - mutableDirectBuffer.putBytes(fieldOffset + 1, mimeType.getBytes()); - - return 1 + mimeType.length(); - } - - private static byte[] getMimeType(final DirectBuffer directBuffer, final int fieldOffset) - { - final int length = directBuffer.getByte(fieldOffset); - final byte[] bytes = new byte[length]; - - directBuffer.getBytes(fieldOffset + 1, bytes); - return bytes; - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/ThreadLocalFramePool.java b/src/main/java/io/reactivesocket/internal/frame/ThreadLocalFramePool.java deleted file mode 100644 index b7b6e1368..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/ThreadLocalFramePool.java +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.Frame; -import org.agrona.MutableDirectBuffer; -import org.agrona.concurrent.OneToOneConcurrentArrayQueue; -import org.agrona.concurrent.UnsafeBuffer; - -import java.nio.ByteBuffer; - -public class ThreadLocalFramePool implements FramePool -{ - private static final int MAX_CAHED_FRAMES_PER_THREAD = 16; - - private static final ThreadLocal> PER_THREAD_FRAME_QUEUE = - ThreadLocal.withInitial(() -> new OneToOneConcurrentArrayQueue<>(MAX_CAHED_FRAMES_PER_THREAD)); - - private static final ThreadLocal> PER_THREAD_DIRECTBUFFER_QUEUE = - ThreadLocal.withInitial(() -> new OneToOneConcurrentArrayQueue<>(MAX_CAHED_FRAMES_PER_THREAD)); - - public Frame acquireFrame(int size) - { - final MutableDirectBuffer directBuffer = acquireMutableDirectBuffer(size); - - Frame frame = pollFrame(); - if (null == frame) - { - frame = Frame.allocate(directBuffer); - } - - return frame; - } - - public Frame acquireFrame(ByteBuffer byteBuffer) - { - return Frame.allocate(new UnsafeBuffer(byteBuffer)); - } - - public void release(Frame frame) - { - PER_THREAD_FRAME_QUEUE.get().offer(frame); - } - - public Frame acquireFrame(MutableDirectBuffer mutableDirectBuffer) - { - Frame frame = pollFrame(); - if (null == frame) - { - frame = Frame.allocate(mutableDirectBuffer); - } - - return frame; - } - - public MutableDirectBuffer acquireMutableDirectBuffer(ByteBuffer byteBuffer) - { - MutableDirectBuffer directBuffer = pollMutableDirectBuffer(); - if (null == directBuffer) - { - directBuffer = new UnsafeBuffer(byteBuffer); - } - - return directBuffer; - } - - public MutableDirectBuffer acquireMutableDirectBuffer(int size) - { - UnsafeBuffer directBuffer = (UnsafeBuffer)pollMutableDirectBuffer(); - if (null == directBuffer || directBuffer.byteBuffer().capacity() < size) - { - directBuffer = new UnsafeBuffer(ByteBuffer.allocate(size)); - } - else - { - directBuffer.byteBuffer().limit(size).position(0); - } - - return directBuffer; - } - - public void release(MutableDirectBuffer mutableDirectBuffer) - { - PER_THREAD_DIRECTBUFFER_QUEUE.get().offer(mutableDirectBuffer); - } - - private Frame pollFrame() - { - return PER_THREAD_FRAME_QUEUE.get().poll(); - } - - private MutableDirectBuffer pollMutableDirectBuffer() - { - return PER_THREAD_DIRECTBUFFER_QUEUE.get().poll(); - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/ThreadSafeFramePool.java b/src/main/java/io/reactivesocket/internal/frame/ThreadSafeFramePool.java deleted file mode 100644 index bd0e00cbf..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/ThreadSafeFramePool.java +++ /dev/null @@ -1,129 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.Frame; -import org.agrona.MutableDirectBuffer; -import org.agrona.concurrent.OneToOneConcurrentArrayQueue; -import org.agrona.concurrent.UnsafeBuffer; - -import java.nio.ByteBuffer; - -public class ThreadSafeFramePool implements FramePool -{ - private static final int MAX_CACHED_FRAMES = 16; - - private final OneToOneConcurrentArrayQueue frameQueue; - private final OneToOneConcurrentArrayQueue directBufferQueue; - - public ThreadSafeFramePool() - { - this(MAX_CACHED_FRAMES, MAX_CACHED_FRAMES); - } - - public ThreadSafeFramePool(final int frameQueueLength, final int directBufferQueueLength) - { - frameQueue = new OneToOneConcurrentArrayQueue<>(frameQueueLength); - directBufferQueue = new OneToOneConcurrentArrayQueue<>(directBufferQueueLength); - } - - public Frame acquireFrame(int size) - { - final MutableDirectBuffer directBuffer = acquireMutableDirectBuffer(size); - - Frame frame = pollFrame(); - if (null == frame) - { - frame = Frame.allocate(directBuffer); - } - - return frame; - } - - public Frame acquireFrame(ByteBuffer byteBuffer) - { - return Frame.allocate(new UnsafeBuffer(byteBuffer)); - } - - public Frame acquireFrame(MutableDirectBuffer mutableDirectBuffer) - { - Frame frame = pollFrame(); - if (null == frame) - { - frame = Frame.allocate(mutableDirectBuffer); - } - - return frame; - } - - public MutableDirectBuffer acquireMutableDirectBuffer(ByteBuffer byteBuffer) - { - MutableDirectBuffer directBuffer = pollMutableDirectBuffer(); - if (null == directBuffer) - { - directBuffer = new UnsafeBuffer(byteBuffer); - } - - return directBuffer; - } - - public MutableDirectBuffer acquireMutableDirectBuffer(int size) - { - UnsafeBuffer directBuffer = (UnsafeBuffer)pollMutableDirectBuffer(); - if (null == directBuffer || directBuffer.capacity() < size) - { - directBuffer = new UnsafeBuffer(ByteBuffer.allocate(size)); - } - else - { - directBuffer.byteBuffer().limit(size).position(0); - } - - return directBuffer; - } - - public void release(Frame frame) - { - synchronized (frameQueue) - { - frameQueue.offer(frame); - } - } - - public void release(MutableDirectBuffer mutableDirectBuffer) - { - synchronized (directBufferQueue) - { - directBufferQueue.offer(mutableDirectBuffer); - } - } - - private Frame pollFrame() - { - synchronized (frameQueue) - { - return frameQueue.poll(); - } - } - - private MutableDirectBuffer pollMutableDirectBuffer() - { - synchronized (directBufferQueue) - { - return directBufferQueue.poll(); - } - } -} diff --git a/src/main/java/io/reactivesocket/internal/frame/UnpooledFrame.java b/src/main/java/io/reactivesocket/internal/frame/UnpooledFrame.java deleted file mode 100644 index 09f04288f..000000000 --- a/src/main/java/io/reactivesocket/internal/frame/UnpooledFrame.java +++ /dev/null @@ -1,65 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.frame; - -import io.reactivesocket.Frame; -import org.agrona.MutableDirectBuffer; -import org.agrona.concurrent.UnsafeBuffer; - -import java.nio.ByteBuffer; - -/** - * On demand creation for Frames, MutableDirectBuffer backed by ByteBuffers of required capacity - */ -public class UnpooledFrame implements FramePool -{ - /* - * TODO: have all gneration of UnsafeBuffer and ByteBuffer hidden behind acquire() calls (private for ByteBuffer) - */ - - public Frame acquireFrame(int size) - { - return Frame.allocate(new UnsafeBuffer(ByteBuffer.allocate(size))); - } - - public Frame acquireFrame(ByteBuffer byteBuffer) - { - return Frame.allocate(new UnsafeBuffer(byteBuffer)); - } - - public void release(Frame frame) - { - } - - public Frame acquireFrame(MutableDirectBuffer mutableDirectBuffer) - { - return Frame.allocate(mutableDirectBuffer); - } - - public MutableDirectBuffer acquireMutableDirectBuffer(ByteBuffer byteBuffer) - { - return new UnsafeBuffer(byteBuffer); - } - - public MutableDirectBuffer acquireMutableDirectBuffer(int size) - { - return new UnsafeBuffer(ByteBuffer.allocate(size)); - } - - public void release(MutableDirectBuffer mutableDirectBuffer) - { - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/AppendOnlyLinkedArrayList.java b/src/main/java/io/reactivesocket/internal/rx/AppendOnlyLinkedArrayList.java deleted file mode 100644 index 0b1ee24b7..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/AppendOnlyLinkedArrayList.java +++ /dev/null @@ -1,125 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -package io.reactivesocket.internal.rx; - -import java.util.function.*; - -/** - * A linked-array-list implementation that only supports appending and consumption. - * - * @param the value type - */ -public class AppendOnlyLinkedArrayList { - final int capacity; - Object[] head; - Object[] tail; - int offset; - - /** - * Constructs an empty list with a per-link capacity - * @param capacity the capacity of each link - */ - public AppendOnlyLinkedArrayList(int capacity) { - this.capacity = capacity; - this.head = new Object[capacity + 1]; - this.tail = head; - } - - /** - * Append a non-null value to the list. - *

Don't add null to the list! - * @param value the value to append - */ - public void add(T value) { - final int c = capacity; - int o = offset; - if (o == c) { - Object[] next = new Object[c + 1]; - tail[c] = next; - tail = next; - o = 0; - } - tail[o] = value; - offset = o + 1; - } - - /** - * Set a value as the first element of the list. - * @param value the value to set - */ - public void setFirst(T value) { - head[0] = value; - } - - /** - * Loops through all elements of the list. - * @param consumer the consumer of elements - */ - @SuppressWarnings("unchecked") - public void forEach(Consumer consumer) { - Object[] a = head; - final int c = capacity; - while (a != null) { - for (int i = 0; i < c; i++) { - Object o = a[i]; - if (o == null) { - return; - } - consumer.accept((T)o); - } - a = (Object[])a[c]; - } - } - - /** - * Loops over all elements of the array until a null element is encountered or - * the given predicate returns true. - * @param consumer the consumer of values that returns true if the forEach should terminate - */ - @SuppressWarnings("unchecked") - public void forEachWhile(Predicate consumer) { - Object[] a = head; - final int c = capacity; - while (a != null) { - for (int i = 0; i < c; i++) { - Object o = a[i]; - if (o == null) { - return; - } - if (consumer.test((T)o)) { - return; - } - } - a = (Object[])a[c]; - } - } - - @SuppressWarnings("unchecked") - public void forEachWhile(S state, BiPredicate consumer) { - Object[] a = head; - final int c = capacity; - while (a != null) { - for (int i = 0; i < c; i++) { - Object o = a[i]; - if (o == null) { - return; - } - if (consumer.test(state, (T)o)) { - return; - } - } - a = (Object[])a[c]; - } - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/BackpressureHelper.java b/src/main/java/io/reactivesocket/internal/rx/BackpressureHelper.java deleted file mode 100644 index cd51ac01b..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/BackpressureHelper.java +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ -package io.reactivesocket.internal.rx; - -import java.util.concurrent.atomic.*; - -/** - * Utility class to help with backpressure-related operations such as request aggregation. - */ -public enum BackpressureHelper { - ; - /** - * Adds two long values and caps the sum at Long.MAX_VALUE. - * @param a the first value - * @param b the second value - * @return the sum capped at Long.MAX_VALUE - */ - public static long addCap(long a, long b) { - long u = a + b; - if (u < 0L) { - return Long.MAX_VALUE; - } - return u; - } - - /** - * Multiplies two long values and caps the product at Long.MAX_VALUE. - * @param a the first value - * @param b the second value - * @return the product capped at Long.MAX_VALUE - */ - public static long multiplyCap(long a, long b) { - long u = a * b; - if (((a | b) >>> 31) != 0) { - if (u / a != b) { - return Long.MAX_VALUE; - } - } - return u; - } - - /** - * Atomically adds the positive value n to the requested value in the AtomicLong and - * caps the result at Long.MAX_VALUE and returns the previous value. - * @param requested the AtomicLong holding the current requested value - * @param n the value to add, must be positive (not verified) - * @return the original value before the add - */ - public static long add(AtomicLong requested, long n) { - for (;;) { - long r = requested.get(); - if (r == Long.MAX_VALUE) { - return Long.MAX_VALUE; - } - long u = addCap(r, n); - if (requested.compareAndSet(r, u)) { - return r; - } - } - } - - /** - * Atomically adds the positive value n to the value in the instance through the field updater and - * caps the result at Long.MAX_VALUE and returns the previous value. - * @param updater the field updater for the requested value - * @param instance the instance holding the requested value - * @param n the value to add, must be positive (not verified) - * @return the original value before the add - */ - public static long add(AtomicLongFieldUpdater updater, T instance, long n) { - for (;;) { - long r = updater.get(instance); - if (r == Long.MAX_VALUE) { - return Long.MAX_VALUE; - } - long u = addCap(r, n); - if (updater.compareAndSet(instance, r, u)) { - return r; - } - } - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/BackpressureUtils.java b/src/main/java/io/reactivesocket/internal/rx/BackpressureUtils.java deleted file mode 100644 index d67598d1f..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/BackpressureUtils.java +++ /dev/null @@ -1,107 +0,0 @@ -package io.reactivesocket.internal.rx; - -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not - * use this file except in compliance with the License. You may obtain a copy of - * the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations under - * the License. - */ -// Copied from RxJava: https://github.com/ReactiveX/RxJava/blob/1.x/src/main/java/rx/internal/operators/BackpressureUtils.java - -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicLongFieldUpdater; - -/** - * Utility functions for use with backpressure. - */ -public final class BackpressureUtils { - /** Utility class, no instances. */ - private BackpressureUtils() { - throw new IllegalStateException("No instances!"); - } - /** - * Adds {@code n} to {@code requested} field and returns the value prior to - * addition once the addition is successful (uses CAS semantics). If - * overflows then sets {@code requested} field to {@code Long.MAX_VALUE}. - * - * @param requested - * atomic field updater for a request count - * @param object - * contains the field updated by the updater - * @param n - * the number of requests to add to the requested count - * @return requested value just prior to successful addition - */ - public static long getAndAddRequest(AtomicLongFieldUpdater requested, T object, long n) { - // add n to field but check for overflow - while (true) { - long current = requested.get(object); - long next = addCap(current, n); - if (requested.compareAndSet(object, current, next)) { - return current; - } - } - } - - /** - * Adds {@code n} to {@code requested} and returns the value prior to addition once the - * addition is successful (uses CAS semantics). If overflows then sets - * {@code requested} field to {@code Long.MAX_VALUE}. - * - * @param requested - * atomic long that should be updated - * @param n - * the number of requests to add to the requested count - * @return requested value just prior to successful addition - */ - public static long getAndAddRequest(AtomicLong requested, long n) { - // add n to field but check for overflow - while (true) { - long current = requested.get(); - long next = addCap(current, n); - if (requested.compareAndSet(current, next)) { - return current; - } - } - } - - /** - * Multiplies two positive longs and caps the result at Long.MAX_VALUE. - * @param a the first value - * @param b the second value - * @return the capped product of a and b - */ - public static long multiplyCap(long a, long b) { - long u = a * b; - if (((a | b) >>> 31) != 0) { - if (b != 0L && (u / b != a)) { - u = Long.MAX_VALUE; - } - } - return u; - } - - /** - * Adds two positive longs and caps the result at Long.MAX_VALUE. - * @param a the first value - * @param b the second value - * @return the capped sum of a and b - */ - public static long addCap(long a, long b) { - long u = a + b; - if (u < 0L) { - u = Long.MAX_VALUE; - } - return u; - } - -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/BaseArrayQueue.java b/src/main/java/io/reactivesocket/internal/rx/BaseArrayQueue.java deleted file mode 100644 index 214aa00b2..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/BaseArrayQueue.java +++ /dev/null @@ -1,131 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -/* - * The code was inspired by the similarly named JCTools class: - * https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/queues/atomic - */ - -package io.reactivesocket.internal.rx; - -import java.util.*; -import java.util.concurrent.atomic.AtomicReferenceArray; - -abstract class BaseArrayQueue extends AtomicReferenceArray implements Queue { - /** */ - private static final long serialVersionUID = 5238363267841964068L; - protected final int mask; - public BaseArrayQueue(int capacity) { - super(Pow2.roundToPowerOfTwo(capacity)); - this.mask = length() - 1; - } - @Override - public Iterator iterator() { - throw new UnsupportedOperationException(); - } - @Override - public void clear() { - // we have to test isEmpty because of the weaker poll() guarantee - while (poll() != null || !isEmpty()) - ; - } - protected final int calcElementOffset(long index, int mask) { - return (int)index & mask; - } - protected final int calcElementOffset(long index) { - return (int)index & mask; - } - protected final E lvElement(AtomicReferenceArray buffer, int offset) { - return buffer.get(offset); - } - protected final E lpElement(AtomicReferenceArray buffer, int offset) { - return buffer.get(offset); // no weaker form available - } - protected final E lpElement(int offset) { - return get(offset); // no weaker form available - } - protected final void spElement(AtomicReferenceArray buffer, int offset, E value) { - buffer.lazySet(offset, value); // no weaker form available - } - protected final void spElement(int offset, E value) { - lazySet(offset, value); // no weaker form available - } - protected final void soElement(AtomicReferenceArray buffer, int offset, E value) { - buffer.lazySet(offset, value); - } - protected final void soElement(int offset, E value) { - lazySet(offset, value); - } - protected final void svElement(AtomicReferenceArray buffer, int offset, E value) { - buffer.set(offset, value); - } - protected final E lvElement(int offset) { - return get(offset); - } - - @Override - public boolean add(E e) { - throw new UnsupportedOperationException(); - } - - @Override - public E remove() { - throw new UnsupportedOperationException(); - } - - @Override - public E element() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean contains(Object o) { - throw new UnsupportedOperationException(); - } - - @Override - public Object[] toArray() { - throw new UnsupportedOperationException(); - } - - @Override - public T[] toArray(T[] a) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean remove(Object o) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean containsAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean addAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean removeAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean retainAll(Collection c) { - throw new UnsupportedOperationException(); - } -} - diff --git a/src/main/java/io/reactivesocket/internal/rx/BaseLinkedQueue.java b/src/main/java/io/reactivesocket/internal/rx/BaseLinkedQueue.java deleted file mode 100644 index bc1047ae2..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/BaseLinkedQueue.java +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -/* - * The code was inspired by the similarly named JCTools class: - * https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/queues/atomic - */ - -package io.reactivesocket.internal.rx; - -import java.util.*; -import java.util.concurrent.atomic.AtomicReference; - -abstract class BaseLinkedQueue extends AbstractQueue { - private final AtomicReference> producerNode; - private final AtomicReference> consumerNode; - public BaseLinkedQueue() { - producerNode = new AtomicReference<>(); - consumerNode = new AtomicReference<>(); - } - protected final LinkedQueueNode lvProducerNode() { - return producerNode.get(); - } - protected final LinkedQueueNode lpProducerNode() { - return producerNode.get(); - } - protected final void spProducerNode(LinkedQueueNode node) { - producerNode.lazySet(node); - } - protected final LinkedQueueNode xchgProducerNode(LinkedQueueNode node) { - return producerNode.getAndSet(node); - } - protected final LinkedQueueNode lvConsumerNode() { - return consumerNode.get(); - } - - protected final LinkedQueueNode lpConsumerNode() { - return consumerNode.get(); - } - protected final void spConsumerNode(LinkedQueueNode node) { - consumerNode.lazySet(node); - } - @Override - public final Iterator iterator() { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc}
- *

- * IMPLEMENTATION NOTES:
- * This is an O(n) operation as we run through all the nodes and count them.
- * - * @see java.util.Queue#size() - */ - @Override - public final int size() { - LinkedQueueNode chaserNode = lvConsumerNode(); - final LinkedQueueNode producerNode = lvProducerNode(); - int size = 0; - // must chase the nodes all the way to the producer node, but there's no need to chase a moving target. - while (chaserNode != producerNode && size < Integer.MAX_VALUE) { - LinkedQueueNode next; - while((next = chaserNode.lvNext()) == null); - chaserNode = next; - size++; - } - return size; - } - /** - * {@inheritDoc}
- *

- * IMPLEMENTATION NOTES:
- * Queue is empty when producerNode is the same as consumerNode. An alternative implementation would be to observe - * the producerNode.value is null, which also means an empty queue because only the consumerNode.value is allowed to - * be null. - * - * @see MessagePassingQueue#isEmpty() - */ - @Override - public final boolean isEmpty() { - return lvConsumerNode() == lvProducerNode(); - } -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/BooleanDisposable.java b/src/main/java/io/reactivesocket/internal/rx/BooleanDisposable.java deleted file mode 100644 index 6e4b70d27..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/BooleanDisposable.java +++ /dev/null @@ -1,37 +0,0 @@ -package io.reactivesocket.internal.rx; - -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; - -import io.reactivesocket.rx.Disposable; - -public final class BooleanDisposable implements Disposable { - volatile Runnable run; - - static final AtomicReferenceFieldUpdater RUN = - AtomicReferenceFieldUpdater.newUpdater(BooleanDisposable.class, Runnable.class, "run"); - - static final Runnable DISPOSED = () -> { }; - - public BooleanDisposable() { - this(() -> { }); - } - - public BooleanDisposable(Runnable run) { - RUN.lazySet(this, run); - } - - @Override - public void dispose() { - Runnable r = run; - if (r != DISPOSED) { - r = RUN.getAndSet(this, DISPOSED); - if (r != DISPOSED) { - r.run(); - } - } - } - - public boolean isDisposed() { - return run == DISPOSED; - } -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/CompositeCompletable.java b/src/main/java/io/reactivesocket/internal/rx/CompositeCompletable.java deleted file mode 100644 index 7d4f4a020..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/CompositeCompletable.java +++ /dev/null @@ -1,85 +0,0 @@ -package io.reactivesocket.internal.rx; - -import java.util.HashSet; -import java.util.Set; - -import io.reactivesocket.rx.Completable; - -/** - * A Completable container that can hold onto multiple other Completables. - */ -public final class CompositeCompletable implements Completable { - - // protected by synchronized - private boolean completed = false; - private Throwable error = null; - final Set resources = new HashSet<>(); - - public CompositeCompletable() { - - } - - public void add(Completable d) { - boolean terminal = false; - synchronized (this) { - if (error != null || completed) { - terminal = true; - } else { - resources.add(d); - } - } - if (terminal) { - if (error != null) { - d.error(error); - } else { - d.success(); - } - } - } - - public void remove(Completable d) { - synchronized (this) { - resources.remove(d); - } - } - - public void clear() { - synchronized (this) { - resources.clear(); - } - } - - @Override - public void success() { - Completable[] cs = null; - synchronized (this) { - if (error == null) { - completed = true; - cs = resources.toArray(new Completable[] {}); - resources.clear(); - } - } - if (cs != null) { - for (Completable c : cs) { - c.success(); - } - } - } - - @Override - public void error(Throwable e) { - Completable[] cs = null; - synchronized (this) { - if (error == null && !completed) { - error = e; - cs = resources.toArray(new Completable[] {}); - resources.clear(); - } - } - if (cs != null) { - for (Completable c : cs) { - c.error(e); - } - } - } -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/CompositeDisposable.java b/src/main/java/io/reactivesocket/internal/rx/CompositeDisposable.java deleted file mode 100644 index f46a65901..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/CompositeDisposable.java +++ /dev/null @@ -1,61 +0,0 @@ -package io.reactivesocket.internal.rx; - -import java.util.HashSet; -import java.util.Set; - -import io.reactivesocket.rx.Completable; -import io.reactivesocket.rx.Disposable; - -/** - * A Disposable container that can hold onto multiple other Disposables. - */ -public final class CompositeDisposable implements Disposable { - - // protected by synchronized - private boolean disposed = false; - final Set resources = new HashSet<>(); - - public CompositeDisposable() { - - } - - public void add(Disposable d) { - boolean isDisposed = false; - synchronized (this) { - if (disposed) { - isDisposed = true; - } else { - resources.add(d); - } - } - if (isDisposed) { - d.dispose(); - } - } - - public void remove(Completable d) { - synchronized (this) { - resources.remove(d); - } - } - - public void clear() { - synchronized (this) { - resources.clear(); - } - } - - @Override - public void dispose() { - Disposable[] cs; - synchronized (this) { - disposed = true; - cs = resources.toArray(new Disposable[] {}); - resources.clear(); - } - for (Disposable d : cs) { - d.dispose(); - } - } - -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/EmptyDisposable.java b/src/main/java/io/reactivesocket/internal/rx/EmptyDisposable.java deleted file mode 100644 index f69d4662e..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/EmptyDisposable.java +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal.rx; - -import io.reactivesocket.rx.Disposable; - -public class EmptyDisposable implements Disposable -{ - public static final EmptyDisposable EMPTY = new EmptyDisposable(); - - public void dispose() - { - } - - public boolean isDisposed() { - return false; - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/EmptySubscription.java b/src/main/java/io/reactivesocket/internal/rx/EmptySubscription.java deleted file mode 100644 index fe2c38685..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/EmptySubscription.java +++ /dev/null @@ -1,67 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -package io.reactivesocket.internal.rx; - -import org.reactivestreams.*; - -/** - * An empty subscription that does nothing other than validates the request amount. - */ -public enum EmptySubscription implements Subscription { - /** A singleton, stateless instance. */ - INSTANCE; - - @Override - public void request(long n) { - SubscriptionHelper.validateRequest(n); - } - @Override - public void cancel() { - // no-op - } - - @Override - public String toString() { - return "EmptySubscription"; - } - - /** - * Sets the empty subscription instance on the subscriber and then - * calls onError with the supplied error. - * - *

Make sure this is only called if the subscriber hasn't received a - * subscription already (there is no way of telling this). - * - * @param e the error to deliver to the subscriber - * @param s the target subscriber - */ - public static void error(Throwable e, Subscriber s) { - s.onSubscribe(INSTANCE); - s.onError(e); - } - - /** - * Sets the empty subscription instance on the subscriber and then - * calls onComplete. - * - *

Make sure this is only called if the subscriber hasn't received a - * subscription already (there is no way of telling this). - * - * @param s the target subscriber - */ - public static void complete(Subscriber s) { - s.onSubscribe(INSTANCE); - s.onComplete(); - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/LinkedQueueNode.java b/src/main/java/io/reactivesocket/internal/rx/LinkedQueueNode.java deleted file mode 100644 index bc03d6f0c..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/LinkedQueueNode.java +++ /dev/null @@ -1,58 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -/* - * The code was inspired by the similarly named JCTools class: - * https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/queues/atomic - */ - -package io.reactivesocket.internal.rx; - -import java.util.concurrent.atomic.AtomicReference; - -public final class LinkedQueueNode extends AtomicReference> { - /** */ - private static final long serialVersionUID = 2404266111789071508L; - private E value; - LinkedQueueNode() { - } - LinkedQueueNode(E val) { - spValue(val); - } - /** - * Gets the current value and nulls out the reference to it from this node. - * - * @return value - */ - public E getAndNullValue() { - E temp = lpValue(); - spValue(null); - return temp; - } - - public E lpValue() { - return value; - } - - public void spValue(E newValue) { - value = newValue; - } - - public void soNext(LinkedQueueNode n) { - lazySet(n); - } - - public LinkedQueueNode lvNext() { - return get(); - } -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/MpscLinkedQueue.java b/src/main/java/io/reactivesocket/internal/rx/MpscLinkedQueue.java deleted file mode 100644 index 45741ac38..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/MpscLinkedQueue.java +++ /dev/null @@ -1,112 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -/* - * The code was inspired by the similarly named JCTools class: - * https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/queues/atomic - */ - -package io.reactivesocket.internal.rx; - -/** - * A multi-producer single consumer unbounded queue. - */ -public final class MpscLinkedQueue extends BaseLinkedQueue { - - public MpscLinkedQueue() { - super(); - LinkedQueueNode node = new LinkedQueueNode<>(); - spConsumerNode(node); - xchgProducerNode(node);// this ensures correct construction: StoreLoad - } - /** - * {@inheritDoc}
- *

- * IMPLEMENTATION NOTES:
- * Offer is allowed from multiple threads.
- * Offer allocates a new node and: - *

    - *
  1. Swaps it atomically with current producer node (only one producer 'wins') - *
  2. Sets the new node as the node following from the swapped producer node - *
- * This works because each producer is guaranteed to 'plant' a new node and link the old node. No 2 producers can - * get the same producer node as part of XCHG guarantee. - * - * @see MessagePassingQueue#offer(Object) - * @see java.util.Queue#offer(java.lang.Object) - */ - @Override - public final boolean offer(final T nextValue) { - final LinkedQueueNode nextNode = new LinkedQueueNode<>(nextValue); - final LinkedQueueNode prevProducerNode = xchgProducerNode(nextNode); - // Should a producer thread get interrupted here the chain WILL be broken until that thread is resumed - // and completes the store in prev.next. - prevProducerNode.soNext(nextNode); // StoreStore - return true; - } - - /** - * {@inheritDoc}
- *

- * IMPLEMENTATION NOTES:
- * Poll is allowed from a SINGLE thread.
- * Poll reads the next node from the consumerNode and: - *

    - *
  1. If it is null, the queue is assumed empty (though it might not be). - *
  2. If it is not null set it as the consumer node and return it's now evacuated value. - *
- * This means the consumerNode.value is always null, which is also the starting point for the queue. Because null - * values are not allowed to be offered this is the only node with it's value set to null at any one time. - * - * @see MessagePassingQueue#poll() - * @see java.util.Queue#poll() - */ - @Override - public final T poll() { - LinkedQueueNode currConsumerNode = lpConsumerNode(); // don't load twice, it's alright - LinkedQueueNode nextNode = currConsumerNode.lvNext(); - if (nextNode != null) { - // we have to null out the value because we are going to hang on to the node - final T nextValue = nextNode.getAndNullValue(); - spConsumerNode(nextNode); - return nextValue; - } - else if (currConsumerNode != lvProducerNode()) { - // spin, we are no longer wait free - while((nextNode = currConsumerNode.lvNext()) == null); - // got the next node... - - // we have to null out the value because we are going to hang on to the node - final T nextValue = nextNode.getAndNullValue(); - spConsumerNode(nextNode); - return nextValue; - } - return null; - } - - @Override - public final T peek() { - LinkedQueueNode currConsumerNode = lpConsumerNode(); // don't load twice, it's alright - LinkedQueueNode nextNode = currConsumerNode.lvNext(); - if (nextNode != null) { - return nextNode.lpValue(); - } else - if (currConsumerNode != lvProducerNode()) { - // spin, we are no longer wait free - while ((nextNode = currConsumerNode.lvNext()) == null); - // got the next node... - return nextNode.lpValue(); - } - return null; - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/NotificationLite.java b/src/main/java/io/reactivesocket/internal/rx/NotificationLite.java deleted file mode 100644 index 2091ed836..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/NotificationLite.java +++ /dev/null @@ -1,207 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ -package io.reactivesocket.internal.rx; - -import java.io.Serializable; - -import org.reactivestreams.*; - -/** - * Lightweight notification handling utility class. - */ -public enum NotificationLite { - // No instances - ; - - /** - * Indicates a completion notification. - */ - private enum Complete { - INSTANCE; - @Override - public String toString() { - return "NotificationLite.Complete"; - }; - } - - /** - * Wraps a Throwable. - */ - private static final class ErrorNotification implements Serializable { - /** */ - private static final long serialVersionUID = -8759979445933046293L; - final Throwable e; - ErrorNotification(Throwable e) { - this.e = e; - } - - @Override - public String toString() { - return "NotificationLite.Error[" + e + "]"; - } - } - - /** - * Wraps a Subscription. - */ - private static final class SubscriptionNotification implements Serializable { - /** */ - private static final long serialVersionUID = -1322257508628817540L; - final Subscription s; - SubscriptionNotification(Subscription s) { - this.s = s; - } - - @Override - public String toString() { - return "NotificationLite.Subscription[" + s + "]"; - } - } - - /** - * Converts a value into a notification value. - * @param value the value to convert - * @return the notification representing the value - */ - public static Object next(T value) { - return value; - } - - /** - * Returns a complete notification. - * @return a complete notification - */ - public static Object complete() { - return Complete.INSTANCE; - } - - /** - * Converts a Throwable into a notification value. - * @param e the Throwable to convert - * @return the notification representing the Throwable - */ - public static Object error(Throwable e) { - return new ErrorNotification(e); - } - - /** - * Converts a Subscription into a notification value. - * @param e the Subscription to convert - * @return the notification representing the Subscription - */ - public static Object subscription(Subscription s) { - return new SubscriptionNotification(s); - } - - /** - * Checks if the given object represents a complete notification. - * @param o the object to check - * @return true if the object represents a complete notification - */ - public static boolean isComplete(Object o) { - return o == Complete.INSTANCE; - } - - /** - * Checks if the given object represents a error notification. - * @param o the object to check - * @return true if the object represents a error notification - */ - public static boolean isError(Object o) { - return o instanceof ErrorNotification; - } - - /** - * Checks if the given object represents a subscription notification. - * @param o the object to check - * @return true if the object represents a subscription notification - */ - public static boolean isSubscription(Object o) { - return o instanceof SubscriptionNotification; - } - - /** - * Extracts the value from the notification object - * @param o the notification object - * @return the extracted value - */ - @SuppressWarnings("unchecked") - public static T getValue(Object o) { - return (T)o; - } - - /** - * Extracts the Throwable from the notification object - * @param o the notification object - * @return the extracted Throwable - */ - public static Throwable getError(Object o) { - return ((ErrorNotification)o).e; - } - - /** - * Extracts the Subscription from the notification object - * @param o the notification object - * @return the extracted Subscription - */ - public static Subscription getSubscription(Object o) { - return ((SubscriptionNotification)o).s; - } - - /** - * Calls the appropriate Subscriber method based on the type of the notification. - *

Does not check for a subscription notification, see {@link #acceptFull(Object, Subscriber)}. - * @param o the notification object - * @param s the subscriber to call methods on - * @return true if the notification was a terminal event (i.e., complete or error) - * @see #acceptFull(Object, Subscriber) - */ - @SuppressWarnings("unchecked") - public static boolean accept(Object o, Subscriber s) { - if (o == Complete.INSTANCE) { - s.onComplete(); - return true; - } else - if (o instanceof ErrorNotification) { - s.onError(((ErrorNotification)o).e); - return true; - } - s.onNext((T)o); - return false; - } - - /** - * Calls the appropriate Subscriber method based on the type of the notification. - * @param o the notification object - * @param s the subscriber to call methods on - * @return true if the notification was a terminal event (i.e., complete or error) - * @see #accept(Object, Subscriber) - */ - @SuppressWarnings("unchecked") - public static boolean acceptFull(Object o, Subscriber s) { - if (o == Complete.INSTANCE) { - s.onComplete(); - return true; - } else - if (o instanceof ErrorNotification) { - s.onError(((ErrorNotification)o).e); - return true; - } else - if (o instanceof SubscriptionNotification) { - s.onSubscribe(((SubscriptionNotification)o).s); - return false; - } - s.onNext((T)o); - return false; - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/OperatorConcatMap.java b/src/main/java/io/reactivesocket/internal/rx/OperatorConcatMap.java deleted file mode 100644 index 641d7cf2b..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/OperatorConcatMap.java +++ /dev/null @@ -1,202 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ -package io.reactivesocket.internal.rx; - -import java.util.Queue; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; - -import org.reactivestreams.*; - -public final class OperatorConcatMap { - final Function> mapper; - final int bufferSize; - public OperatorConcatMap(Function> mapper, int bufferSize) { - this.mapper = mapper; - this.bufferSize = bufferSize; - } - - public Subscriber apply(Subscriber s) { - SerializedSubscriber ssub = new SerializedSubscriber<>(s); - SubscriptionArbiter sa = new SubscriptionArbiter(); - ssub.onSubscribe(sa); - return new SourceSubscriber<>(ssub, sa, mapper, bufferSize); - } - - static final class SourceSubscriber extends AtomicInteger implements Subscriber { - /** */ - private static final long serialVersionUID = 8828587559905699186L; - final Subscriber actual; - final SubscriptionArbiter sa; - final Function> mapper; - final Subscriber inner; - final Queue queue; - final int bufferSize; - - Subscription s; - - volatile boolean done; - - volatile long index; - - public SourceSubscriber(Subscriber actual, SubscriptionArbiter sa, - Function> mapper, int bufferSize) { - this.actual = actual; - this.sa = sa; - this.mapper = mapper; - this.bufferSize = bufferSize; - this.inner = new InnerSubscriber<>(actual, sa, this); - Queue q; - if (Pow2.isPowerOfTwo(bufferSize)) { - q = new SpscArrayQueue<>(bufferSize); - } else { - q = new SpscExactArrayQueue<>(bufferSize); - } - this.queue = q; - } - @Override - public void onSubscribe(Subscription s) { - if (this.s != null) { - s.cancel(); - return; - } - this.s = s; - s.request(bufferSize); - } - @Override - public void onNext(T t) { - if (done) { - return; - } - if (!queue.offer(t)) { - cancel(); - actual.onError(new IllegalStateException("More values received than requested!")); - return; - } - if (getAndIncrement() == 0) { - drain(); - } - } - @Override - public void onError(Throwable t) { - if (done) { - return; - } - done = true; - cancel(); - actual.onError(t); - } - @Override - public void onComplete() { - if (done) { - return; - } - done = true; - if (getAndIncrement() == 0) { - drain(); - } - } - - void innerComplete() { - if (decrementAndGet() != 0) { - drain(); - } - if (!done) { - s.request(1); - } - } - - void cancel() { - sa.cancel(); - s.cancel(); - } - - void drain() { - boolean d = done; - T o = queue.poll(); - - if (o == null) { - if (d) { - actual.onComplete(); - return; - } - return; - } - Publisher p; - try { - p = mapper.apply(o); - } catch (Throwable e) { - cancel(); - actual.onError(e); - return; - } - index++; - // this is not RS but since our Subscriber doesn't hold state by itself, - // subscribing it to each source is safe and saves allocation - p.subscribe(inner); - } - } - - static final class InnerSubscriber implements Subscriber { - final Subscriber actual; - final SubscriptionArbiter sa; - final SourceSubscriber parent; - - /* - * FIXME this is a workaround for now, but doesn't work - * for async non-conforming sources. - * Such sources require individual instances of InnerSubscriber and a - * done field. - */ - - long index; - - public InnerSubscriber(Subscriber actual, - SubscriptionArbiter sa, SourceSubscriber parent) { - this.actual = actual; - this.sa = sa; - this.parent = parent; - this.index = 1; - } - - @Override - public void onSubscribe(Subscription s) { - if (index == parent.index) { - sa.setSubscription(s); - } - } - - @Override - public void onNext(U t) { - if (index == parent.index) { - actual.onNext(t); - sa.produced(1L); - } - } - @Override - public void onError(Throwable t) { - if (index == parent.index) { - index++; - parent.cancel(); - actual.onError(t); - } - } - @Override - public void onComplete() { - if (index == parent.index) { - index++; - parent.innerComplete(); - } - } - } -} \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/Pow2.java b/src/main/java/io/reactivesocket/internal/rx/Pow2.java deleted file mode 100644 index 332144a26..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/Pow2.java +++ /dev/null @@ -1,46 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - - -/* - * Original License: https://github.com/JCTools/JCTools/blob/master/LICENSE - * Original location: https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/util/Pow2.java - */ -package io.reactivesocket.internal.rx; - -public final class Pow2 { - private Pow2() { - throw new IllegalStateException("No instances!"); - } - - /** - * Find the next larger positive power of two value up from the given value. If value is a power of two then - * this value will be returned. - * - * @param value from which next positive power of two will be found. - * @return the next positive power of 2 or this value if it is a power of 2. - */ - public static int roundToPowerOfTwo(final int value) { - return 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); - } - - /** - * Is this value a power of two. - * - * @param value to be tested to see if it is a power of two. - * @return true if the value is a power of 2 otherwise false. - */ - public static boolean isPowerOfTwo(final int value) { - return (value & (value - 1)) == 0; - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/QueueDrainHelper.java b/src/main/java/io/reactivesocket/internal/rx/QueueDrainHelper.java deleted file mode 100644 index fbafaff75..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/QueueDrainHelper.java +++ /dev/null @@ -1,280 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ -package io.reactivesocket.internal.rx; - -import java.util.concurrent.atomic.*; -import java.util.function.BooleanSupplier; - -/** - * Utility class to help with the queue-drain serialization idiom. - */ -public enum QueueDrainHelper { - ; - - /** - * A fast-path queue-drain serialization logic. - *

The decrementing of the state is left to the drain callback. - * @param updater - * @param instance - * @param fastPath called if the instance is uncontended. - * @param queue called if the instance is contended to queue up work - * @param drain called if the instance transitions to the drain state successfully - */ - public static void queueDrain(AtomicIntegerFieldUpdater updater, T instance, - Runnable fastPath, Runnable queue, Runnable drain) { - if (updater.get(instance) == 0 && updater.compareAndSet(instance, 0, 1)) { - fastPath.run(); - if (updater.decrementAndGet(instance) == 0) { - return; - } - } else { - queue.run(); - if (updater.getAndIncrement(instance) != 0) { - return; - } - } - drain.run(); - } - - /** - * A fast-path queue-drain serialization logic with the ability to leave the state - * in fastpath/drain mode or not continue after the call to queue. - *

The decrementing of the state is left to the drain callback. - * @param updater - * @param instance - * @param fastPath - * @param queue - * @param drain - */ - public static void queueDrainIf(AtomicIntegerFieldUpdater updater, T instance, - BooleanSupplier fastPath, BooleanSupplier queue, Runnable drain) { - if (updater.get(instance) == 0 && updater.compareAndSet(instance, 0, 1)) { - if (fastPath.getAsBoolean()) { - return; - } - if (updater.decrementAndGet(instance) == 0) { - return; - } - } else { - if (queue.getAsBoolean()) { - return; - } - if (updater.getAndIncrement(instance) != 0) { - return; - } - } - drain.run(); - } - - /** - * A fast-path queue-drain serialization logic where the drain is looped until - * the instance state reaches 0 again. - * @param updater - * @param instance - * @param fastPath - * @param queue - * @param drain - */ - public static void queueDrainLoop(AtomicIntegerFieldUpdater updater, T instance, - Runnable fastPath, Runnable queue, Runnable drain) { - if (updater.get(instance) == 0 && updater.compareAndSet(instance, 0, 1)) { - fastPath.run(); - if (updater.decrementAndGet(instance) == 0) { - return; - } - } else { - queue.run(); - if (updater.getAndIncrement(instance) != 0) { - return; - } - } - int missed = 1; - for (;;) { - drain.run(); - - missed = updater.addAndGet(instance, -missed); - if (missed == 0) { - return; - } - } - } - - /** - * A fast-path queue-drain serialization logic with looped drain call and the ability to leave the state - * in fastpath/drain mode or not continue after the call to queue. - * @param updater - * @param instance - * @param fastPath - * @param queue - * @param drain - */ - public static void queueDrainLoopIf(AtomicIntegerFieldUpdater updater, T instance, - BooleanSupplier fastPath, BooleanSupplier queue, BooleanSupplier drain) { - if (updater.get(instance) == 0 && updater.compareAndSet(instance, 0, 1)) { - if (fastPath.getAsBoolean()) { - return; - } - if (updater.decrementAndGet(instance) == 0) { - return; - } - } else { - if (queue.getAsBoolean()) { - return; - } - if (updater.getAndIncrement(instance) != 0) { - return; - } - } - int missed = 1; - for (;;) { - - if (drain.getAsBoolean()) { - return; - } - - missed = updater.addAndGet(instance, -missed); - if (missed == 0) { - return; - } - } - } - - /** - * A fast-path queue-drain serialization logic. - *

The decrementing of the state is left to the drain callback. - * @param updater - * @param instance - * @param fastPath called if the instance is uncontended. - * @param queue called if the instance is contended to queue up work - * @param drain called if the instance transitions to the drain state successfully - */ - public static void queueDrain(AtomicInteger instance, - Runnable fastPath, Runnable queue, Runnable drain) { - if (instance.get() == 0 && instance.compareAndSet(0, 1)) { - fastPath.run(); - if (instance.decrementAndGet() == 0) { - return; - } - } else { - queue.run(); - if (instance.getAndIncrement() != 0) { - return; - } - } - drain.run(); - } - - /** - * A fast-path queue-drain serialization logic with the ability to leave the state - * in fastpath/drain mode or not continue after the call to queue. - *

The decrementing of the state is left to the drain callback. - * @param updater - * @param instance - * @param fastPath - * @param queue - * @param drain - */ - public static void queueDrainIf(AtomicInteger instance, - BooleanSupplier fastPath, BooleanSupplier queue, Runnable drain) { - if (instance.get() == 0 && instance.compareAndSet(0, 1)) { - if (fastPath.getAsBoolean()) { - return; - } - if (instance.decrementAndGet() == 0) { - return; - } - } else { - if (queue.getAsBoolean()) { - return; - } - if (instance.getAndIncrement() != 0) { - return; - } - } - drain.run(); - } - - /** - * A fast-path queue-drain serialization logic where the drain is looped until - * the instance state reaches 0 again. - * @param updater - * @param instance - * @param fastPath - * @param queue - * @param drain - */ - public static void queueDrainLoop(AtomicInteger instance, - Runnable fastPath, Runnable queue, Runnable drain) { - if (instance.get() == 0 && instance.compareAndSet(0, 1)) { - fastPath.run(); - if (instance.decrementAndGet() == 0) { - return; - } - } else { - queue.run(); - if (instance.getAndIncrement() != 0) { - return; - } - } - int missed = 1; - for (;;) { - drain.run(); - - missed = instance.addAndGet(-missed); - if (missed == 0) { - return; - } - } - } - - /** - * A fast-path queue-drain serialization logic with looped drain call and the ability to leave the state - * in fastpath/drain mode or not continue after the call to queue. - * @param updater - * @param instance - * @param fastPath - * @param queue - * @param drain - */ - public static void queueDrainLoopIf(AtomicInteger instance, - BooleanSupplier fastPath, BooleanSupplier queue, BooleanSupplier drain) { - if (instance.get() == 0 && instance.compareAndSet(0, 1)) { - if (fastPath.getAsBoolean()) { - return; - } - if (instance.decrementAndGet() == 0) { - return; - } - } else { - if (queue.getAsBoolean()) { - return; - } - if (instance.getAndIncrement() != 0) { - return; - } - } - int missed = 1; - for (;;) { - - if (drain.getAsBoolean()) { - return; - } - - missed = instance.addAndGet(-missed); - if (missed == 0) { - return; - } - } - } - -} diff --git a/src/main/java/io/reactivesocket/internal/rx/README.md b/src/main/java/io/reactivesocket/internal/rx/README.md deleted file mode 100644 index dc1b56023..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/README.md +++ /dev/null @@ -1,3 +0,0 @@ -RxJava v2 code copy/pasted to here since RxJava v2 is not yet ready to be depended upon (still in design flux, rapid code changes, not even a developer preview on Maven Central yet). - -Someday this package should theoretically go away and RxJava v2 directly used. \ No newline at end of file diff --git a/src/main/java/io/reactivesocket/internal/rx/SerializedSubscriber.java b/src/main/java/io/reactivesocket/internal/rx/SerializedSubscriber.java deleted file mode 100644 index fab2efcca..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/SerializedSubscriber.java +++ /dev/null @@ -1,176 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ -package io.reactivesocket.internal.rx; - -import org.reactivestreams.*; - - -/** - * Serializes access to the onNext, onError and onComplete methods of another Subscriber. - * - *

Note that onSubscribe is not serialized in respect of the other methods so - * make sure the Subscription is set before any of the other methods are called. - * - *

The implementation assumes that the actual Subscriber's methods don't throw. - * - * @param the value type - */ -public final class SerializedSubscriber implements Subscriber { - final Subscriber actual; - final boolean delayError; - - static final int QUEUE_LINK_SIZE = 4; - - Subscription subscription; - - boolean emitting; - AppendOnlyLinkedArrayList queue; - - volatile boolean done; - - public SerializedSubscriber(Subscriber actual) { - this(actual, false); - } - - public SerializedSubscriber(Subscriber actual, boolean delayError) { - this.actual = actual; - this.delayError = delayError; - } - @Override - public void onSubscribe(Subscription s) { - if (subscription != null) { - s.cancel(); - onError(new IllegalStateException("Subscription already set!")); - return; - } - this.subscription = s; - - actual.onSubscribe(s); - } - - @Override - public void onNext(T t) { - if (done) { - return; - } - if (t == null) { - subscription.cancel(); - onError(new NullPointerException()); - return; - } - synchronized (this) { - if (done) { - return; - } - if (emitting) { - AppendOnlyLinkedArrayList q = queue; - if (q == null) { - q = new AppendOnlyLinkedArrayList<>(QUEUE_LINK_SIZE); - queue = q; - } - q.add(NotificationLite.next(t)); - return; - } - emitting = true; - } - - actual.onNext(t); - - emitLoop(); - } - - @Override - public void onError(Throwable t) { - if (done) { - return; - } - boolean reportError; - synchronized (this) { - if (done) { - reportError = true; - } else - if (emitting) { - done = true; - AppendOnlyLinkedArrayList q = queue; - if (q == null) { - q = new AppendOnlyLinkedArrayList<>(QUEUE_LINK_SIZE); - queue = q; - } - Object err = NotificationLite.error(t); - if (delayError) { - q.add(err); - } else { - q.setFirst(err); - } - return; - } else { - done = true; - emitting = true; - reportError = false; - } - } - - if (reportError) { - return; - } - - actual.onError(t); - // no need to loop because this onError is the last event - } - - @Override - public void onComplete() { - if (done) { - return; - } - synchronized (this) { - if (done) { - return; - } - if (emitting) { - AppendOnlyLinkedArrayList q = queue; - if (q == null) { - q = new AppendOnlyLinkedArrayList<>(QUEUE_LINK_SIZE); - queue = q; - } - q.add(NotificationLite.complete()); - return; - } - done = true; - emitting = true; - } - - actual.onComplete(); - // no need to loop because this onComplete is the last event - } - - void emitLoop() { - for (;;) { - AppendOnlyLinkedArrayList q; - synchronized (this) { - q = queue; - if (q == null) { - emitting = false; - return; - } - queue = null; - } - - q.forEachWhile(this::accept); - } - } - - boolean accept(Object value) { - return NotificationLite.accept(value, actual); - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/SpscArrayQueue.java b/src/main/java/io/reactivesocket/internal/rx/SpscArrayQueue.java deleted file mode 100644 index 348551e68..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/SpscArrayQueue.java +++ /dev/null @@ -1,133 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -/* - * The code was inspired by the similarly named JCTools class: - * https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/queues/atomic - */ - -package io.reactivesocket.internal.rx; - -import java.util.concurrent.atomic.AtomicLong; - -/** - * A Single-Producer-Single-Consumer queue backed by a pre-allocated buffer. - *

- * This implementation is a mashup of the Fast Flow - * algorithm with an optimization of the offer method taken from the BQueue algorithm (a variation on Fast - * Flow), and adjusted to comply with Queue.offer semantics with regards to capacity.
- * For convenience the relevant papers are available in the resources folder:
- * 2010 - Pisa - SPSC Queues on Shared Cache Multi-Core Systems.pdf
- * 2012 - Junchang- BQueue- Efficient and Practical Queuing.pdf
- *
This implementation is wait free. - * - * @param - */ -public final class SpscArrayQueue extends BaseArrayQueue { - /** */ - private static final long serialVersionUID = -1296597691183856449L; - private static final Integer MAX_LOOK_AHEAD_STEP = Integer.getInteger("jctools.spsc.max.lookahead.step", 4096); - final AtomicLong producerIndex; - protected long producerLookAhead; - final AtomicLong consumerIndex; - final int lookAheadStep; - public SpscArrayQueue(int capacity) { - super(capacity); - this.producerIndex = new AtomicLong(); - this.consumerIndex = new AtomicLong(); - lookAheadStep = Math.min(capacity / 4, MAX_LOOK_AHEAD_STEP); - } - - @Override - public boolean offer(E e) { - if (null == e) { - throw new NullPointerException("Null is not a valid element"); - } - // local load of field to avoid repeated loads after volatile reads - final int mask = this.mask; - final long index = producerIndex.get(); - final int offset = calcElementOffset(index, mask); - if (index >= producerLookAhead) { - int step = lookAheadStep; - if (null == lvElement(calcElementOffset(index + step, mask))) {// LoadLoad - producerLookAhead = index + step; - } - else if (null != lvElement(offset)){ - return false; - } - } - soProducerIndex(index + 1); // ordered store -> atomic and ordered for size() - soElement(offset, e); // StoreStore - return true; - } - - @Override - public E poll() { - final long index = consumerIndex.get(); - final int offset = calcElementOffset(index); - // local load of field to avoid repeated loads after volatile reads - final E e = lvElement(offset);// LoadLoad - if (null == e) { - return null; - } - soConsumerIndex(index + 1); // ordered store -> atomic and ordered for size() - soElement(offset, null);// StoreStore - return e; - } - - @Override - public E peek() { - return lvElement(calcElementOffset(consumerIndex.get())); - } - - @Override - public boolean isEmpty() { - return producerIndex.get() == consumerIndex.get(); - } - - @Override - public int size() { - /* - * It is possible for a thread to be interrupted or reschedule between the read of the producer and consumer - * indices, therefore protection is required to ensure size is within valid range. In the event of concurrent - * polls/offers to this method the size is OVER estimated as we read consumer index BEFORE the producer index. - */ - long after = lvConsumerIndex(); - while (true) { - final long before = after; - final long currentProducerIndex = lvProducerIndex(); - after = lvConsumerIndex(); - if (before == after) { - return (int) (currentProducerIndex - after); - } - } - } - - private void soProducerIndex(long newIndex) { - producerIndex.lazySet(newIndex); - } - - private void soConsumerIndex(long newIndex) { - consumerIndex.lazySet(newIndex); - } - - private long lvConsumerIndex() { - return consumerIndex.get(); - } - private long lvProducerIndex() { - return producerIndex.get(); - } - -} - diff --git a/src/main/java/io/reactivesocket/internal/rx/SpscExactArrayQueue.java b/src/main/java/io/reactivesocket/internal/rx/SpscExactArrayQueue.java deleted file mode 100644 index 41b3664b3..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/SpscExactArrayQueue.java +++ /dev/null @@ -1,164 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -/* - * The code was inspired by the similarly named JCTools class: - * https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/queues/atomic - */ - -package io.reactivesocket.internal.rx; - -import java.util.*; -import java.util.concurrent.atomic.*; - -/** - * A single-producer single-consumer array-backed queue with exact, non power-of-2 logical capacity. - */ -public final class SpscExactArrayQueue extends AtomicReferenceArray implements Queue { - /** */ - private static final long serialVersionUID = 6210984603741293445L; - final int mask; - final int capacitySkip; - volatile long producerIndex; - volatile long consumerIndex; - - @SuppressWarnings("rawtypes") - static final AtomicLongFieldUpdater PRODUCER_INDEX = - AtomicLongFieldUpdater.newUpdater(SpscExactArrayQueue.class, "producerIndex"); - @SuppressWarnings("rawtypes") - static final AtomicLongFieldUpdater CONSUMER_INDEX = - AtomicLongFieldUpdater.newUpdater(SpscExactArrayQueue.class, "consumerIndex"); - - public SpscExactArrayQueue(int capacity) { - super(Pow2.roundToPowerOfTwo(capacity)); - int len = length(); - this.mask = len - 1; - this.capacitySkip = len - capacity; - } - - - @Override - public boolean offer(T value) { - Objects.requireNonNull(value); - - long pi = producerIndex; - int m = mask; - - int fullCheck = (int)(pi + capacitySkip) & m; - if (get(fullCheck) != null) { - return false; - } - int offset = (int)pi & m; - PRODUCER_INDEX.lazySet(this, pi + 1); - lazySet(offset, value); - return true; - } - @Override - public T poll() { - long ci = consumerIndex; - int offset = (int)ci & mask; - T value = get(offset); - if (value == null) { - return null; - } - CONSUMER_INDEX.lazySet(this, ci + 1); - lazySet(offset, null); - return value; - } - @Override - public T peek() { - return get((int)consumerIndex & mask); - } - @Override - public void clear() { - while (poll() != null || !isEmpty()); - } - @Override - public boolean isEmpty() { - return producerIndex == consumerIndex; - } - - @Override - public int size() { - long ci = consumerIndex; - for (;;) { - long pi = producerIndex; - long ci2 = consumerIndex; - if (ci == ci2) { - return (int)(pi - ci2); - } - ci = ci2; - } - } - - @Override - public boolean contains(Object o) { - throw new UnsupportedOperationException(); - } - - @Override - public Iterator iterator() { - throw new UnsupportedOperationException(); - } - - @Override - public Object[] toArray() { - throw new UnsupportedOperationException(); - } - - @Override - public E[] toArray(E[] a) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean remove(Object o) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean containsAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean addAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean removeAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean retainAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean add(T e) { - throw new UnsupportedOperationException(); - } - - @Override - public T remove() { - throw new UnsupportedOperationException(); - } - - @Override - public T element() { - throw new UnsupportedOperationException(); - } - -} diff --git a/src/main/java/io/reactivesocket/internal/rx/SubscriptionArbiter.java b/src/main/java/io/reactivesocket/internal/rx/SubscriptionArbiter.java deleted file mode 100644 index 233ab3472..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/SubscriptionArbiter.java +++ /dev/null @@ -1,188 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -package io.reactivesocket.internal.rx; -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -import java.util.*; -import java.util.concurrent.atomic.*; - -import org.reactivestreams.Subscription; - -/** - * Arbitrates requests and cancellation between Subscriptions. - */ -public final class SubscriptionArbiter extends AtomicInteger implements Subscription { - /** */ - private static final long serialVersionUID = -2189523197179400958L; - - final Queue missedSubscription = new MpscLinkedQueue<>(); - - Subscription actual; - long requested; - - volatile boolean cancelled; - - volatile long missedRequested; - static final AtomicLongFieldUpdater MISSED_REQUESTED = - AtomicLongFieldUpdater.newUpdater(SubscriptionArbiter.class, "missedRequested"); - - volatile long missedProduced; - static final AtomicLongFieldUpdater MISSED_PRODUCED = - AtomicLongFieldUpdater.newUpdater(SubscriptionArbiter.class, "missedProduced"); - - private long addRequested(long n) { - long r = requested; - long u = BackpressureHelper.addCap(r, n); - requested = u; - return r; - } - - @Override - public void request(long n) { - if (SubscriptionHelper.validateRequest(n)) { - return; - } - if (cancelled) { - return; - } - QueueDrainHelper.queueDrainLoop(this, () -> { - addRequested(n); - Subscription s = actual; - if (s != null) { - s.request(n); - } - }, () -> { - BackpressureHelper.add(MISSED_REQUESTED, this, n); - }, this::drain); - } - - public void produced(long n) { - if (n <= 0) { - return; - } - QueueDrainHelper.queueDrainLoop(this, () -> { - long r = requested; - if (r == Long.MAX_VALUE) { - return; - } - long u = r - n; - if (u < 0L) { - u = 0; - } - requested = u; - }, () -> { - BackpressureHelper.add(MISSED_PRODUCED, this, n); - }, this::drain); - } - - public void setSubscription(Subscription s) { - Objects.requireNonNull(s); - if (cancelled) { - s.cancel(); - return; - } - QueueDrainHelper.queueDrainLoop(this, () -> { - Subscription a = actual; - if (a != null) { - a.cancel(); - } - actual = s; - long r = requested; - if (r != 0L) { - s.request(r); - } - }, () -> { - missedSubscription.offer(s); - }, this::drain); - } - - @Override - public void cancel() { - if (cancelled) { - return; - } - cancelled = true; - QueueDrainHelper.queueDrainLoop(this, () -> { - Subscription a = actual; - if (a != null) { - actual = null; - a.cancel(); - } - }, () -> { - // nothing to queue - }, this::drain); - } - - public boolean isCancelled() { - return cancelled; - } - - void drain() { - long mr = MISSED_REQUESTED.getAndSet(this, 0L); - long mp = MISSED_PRODUCED.getAndSet(this, 0L); - Subscription ms = missedSubscription.poll(); - boolean c = cancelled; - - long r = requested; - if (r != Long.MAX_VALUE && !c) { - long u = r + mr; - if (u < 0L) { - r = Long.MAX_VALUE; - requested = Long.MAX_VALUE; - } else { - long v = u - mp; - if (v < 0L) { - v = 0L; - } - r = v; - requested = v; - } - } - - Subscription a = actual; - if (c && a != null) { - actual = null; - a.cancel(); - } - - if (ms == null) { - if (a != null && mr != 0L) { - a.request(mr); - } - } else { - if (c) { - ms.cancel(); - } else { - if (a != null) { - a.cancel(); - } - actual = ms; - if (r != 0L) { - ms.request(r); - } - } - } - } -} diff --git a/src/main/java/io/reactivesocket/internal/rx/SubscriptionHelper.java b/src/main/java/io/reactivesocket/internal/rx/SubscriptionHelper.java deleted file mode 100644 index ad72a8d57..000000000 --- a/src/main/java/io/reactivesocket/internal/rx/SubscriptionHelper.java +++ /dev/null @@ -1,79 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is - * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See - * the License for the specific language governing permissions and limitations under the License. - */ - -package io.reactivesocket.internal.rx; - -import org.reactivestreams.*; - -public enum SubscriptionHelper { - ; - - public static boolean validateSubscription(Subscription current, Subscription next) { - if (next == null) { - return true; - } - if (current != null) { - next.cancel(); - return true; - } - return false; - } - - /** - *

- * Make sure error reporting via s.onError is serialized. - * - * @param current - * @param next - * @param s - * @return - */ - public static boolean validateSubscription(Subscription current, Subscription next, Subscriber s) { - if (next == null) { - s.onError(new NullPointerException("next is null")); - return true; - } - if (current != null) { - next.cancel(); - return true; - } - return false; - } - - public static boolean validateRequest(long n) { - if (n <= 0) { - return true; - } - return false; - } - - /** - *

- * Make sure error reporting via s.onError is serialized. - * - * @param n - * @param current - * @param s - * @return - */ - public static boolean validateRequest(long n, Subscription current, Subscriber s) { - if (n <= 0) { - if (current != null) { - current.cancel(); - } - s.onError(new IllegalArgumentException("n > 0 required but it was " + n)); - return true; - } - return false; - } -} diff --git a/src/main/java/io/reactivesocket/lease/FairLeaseGovernor.java b/src/main/java/io/reactivesocket/lease/FairLeaseGovernor.java deleted file mode 100644 index 05eb37c84..000000000 --- a/src/main/java/io/reactivesocket/lease/FairLeaseGovernor.java +++ /dev/null @@ -1,75 +0,0 @@ -package io.reactivesocket.lease; - -import io.reactivesocket.Frame; -import io.reactivesocket.LeaseGovernor; -import io.reactivesocket.internal.Responder; - -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; - -/** - * Distribute evenly a static number of tickets to all connected clients. - */ -public class FairLeaseGovernor implements LeaseGovernor { - private static ScheduledExecutorService EXECUTOR = Executors.newScheduledThreadPool(1); - - private final int tickets; - private final long period; - private final TimeUnit unit; - private final Map responders; - private ScheduledFuture runningTask; - - private synchronized void distribute(int ttlMs) { - if (!responders.isEmpty()) { - int budget = tickets / responders.size(); - - // it would be more fair to randomized the distribution of extra - int extra = tickets - budget * responders.size(); - for (Responder responder: responders.keySet()) { - int n = budget; - if (extra > 0) { - n += 1; - extra -= 1; - } - responder.sendLease(ttlMs, n); - responders.put(responder, n); - } - } - } - - public FairLeaseGovernor(int tickets, long period, TimeUnit unit) { - this.tickets = tickets; - this.period = period; - this.unit = unit; - responders = new HashMap<>(); - } - - @Override - public synchronized void register(Responder responder) { - responders.put(responder, 0); - if (runningTask == null) { - final int ttl = (int)TimeUnit.NANOSECONDS.convert(period, unit); - runningTask = EXECUTOR.scheduleAtFixedRate(() -> distribute(ttl), 0, period, unit); - } - } - - @Override - public synchronized void unregister(Responder responder) { - responders.remove(responder); - if (responders.isEmpty() && runningTask != null) { - runningTask.cancel(true); - runningTask = null; - } - } - - @Override - public synchronized boolean accept(Responder responder, Frame frame) { - boolean valid; - final Integer remainingTickets = responders.get(responder); - return remainingTickets == null || remainingTickets > 0; - } -} diff --git a/src/main/java/io/reactivesocket/lease/NullLeaseGovernor.java b/src/main/java/io/reactivesocket/lease/NullLeaseGovernor.java deleted file mode 100644 index a08fc1bac..000000000 --- a/src/main/java/io/reactivesocket/lease/NullLeaseGovernor.java +++ /dev/null @@ -1,18 +0,0 @@ -package io.reactivesocket.lease; - -import io.reactivesocket.Frame; -import io.reactivesocket.LeaseGovernor; -import io.reactivesocket.internal.Responder; - -public class NullLeaseGovernor implements LeaseGovernor { - @Override - public void register(Responder responder) {} - - @Override - public void unregister(Responder responder) {} - - @Override - public boolean accept(Responder responder, Frame frame) { - return true; - } -} diff --git a/src/main/java/io/reactivesocket/lease/UnlimitedLeaseGovernor.java b/src/main/java/io/reactivesocket/lease/UnlimitedLeaseGovernor.java deleted file mode 100644 index 3cff13ff6..000000000 --- a/src/main/java/io/reactivesocket/lease/UnlimitedLeaseGovernor.java +++ /dev/null @@ -1,20 +0,0 @@ -package io.reactivesocket.lease; - -import io.reactivesocket.Frame; -import io.reactivesocket.LeaseGovernor; -import io.reactivesocket.internal.Responder; - -public class UnlimitedLeaseGovernor implements LeaseGovernor { - @Override - public void register(Responder responder) { - responder.sendLease(Integer.MAX_VALUE, Integer.MAX_VALUE); - } - - @Override - public void unregister(Responder responder) {} - - @Override - public boolean accept(Responder responder, Frame frame) { - return true; - } -} diff --git a/src/main/java/io/reactivesocket/rx/Completable.java b/src/main/java/io/reactivesocket/rx/Completable.java deleted file mode 100644 index 9f87f6a11..000000000 --- a/src/main/java/io/reactivesocket/rx/Completable.java +++ /dev/null @@ -1,9 +0,0 @@ -package io.reactivesocket.rx; - -public interface Completable { - - public abstract void success(); - - public abstract void error(Throwable e); - -} diff --git a/src/main/java/io/reactivesocket/rx/Disposable.java b/src/main/java/io/reactivesocket/rx/Disposable.java deleted file mode 100644 index df6efcda7..000000000 --- a/src/main/java/io/reactivesocket/rx/Disposable.java +++ /dev/null @@ -1,7 +0,0 @@ -package io.reactivesocket.rx; - -public interface Disposable { - - public void dispose(); - -} diff --git a/src/main/java/io/reactivesocket/rx/Observable.java b/src/main/java/io/reactivesocket/rx/Observable.java deleted file mode 100644 index 9c5d6e39d..000000000 --- a/src/main/java/io/reactivesocket/rx/Observable.java +++ /dev/null @@ -1,6 +0,0 @@ -package io.reactivesocket.rx; - -public interface Observable { - - public void subscribe(Observer o); -} diff --git a/src/main/java/io/reactivesocket/rx/Observer.java b/src/main/java/io/reactivesocket/rx/Observer.java deleted file mode 100644 index 5a8bafde7..000000000 --- a/src/main/java/io/reactivesocket/rx/Observer.java +++ /dev/null @@ -1,12 +0,0 @@ -package io.reactivesocket.rx; - -public interface Observer { - - public void onNext(T t); - - public void onError(Throwable e); - - public void onComplete(); - - public void onSubscribe(Disposable d); -} diff --git a/src/main/java/io/reactivesocket/rx/README.md b/src/main/java/io/reactivesocket/rx/README.md deleted file mode 100644 index e75d96494..000000000 --- a/src/main/java/io/reactivesocket/rx/README.md +++ /dev/null @@ -1,3 +0,0 @@ -Interfaces for `Observable` that does not support backpressure. - -TODO: Decide if we just use concrete types from RxJava 2 once this type exists. (Flowable vs Observable) (BenC would prefer this package go away) \ No newline at end of file diff --git a/src/perf/java/io/reactivesocket/FramePerf.java b/src/perf/java/io/reactivesocket/FramePerf.java deleted file mode 100644 index 528ffead5..000000000 --- a/src/perf/java/io/reactivesocket/FramePerf.java +++ /dev/null @@ -1,106 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import java.nio.ByteBuffer; -import java.util.concurrent.TimeUnit; - -import org.apache.commons.math3.stat.inference.TestUtils; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.OutputTimeUnit; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.infra.Blackhole; - -@BenchmarkMode(Mode.Throughput) -@OutputTimeUnit(TimeUnit.SECONDS) -public class FramePerf { - - public static Frame utf8EncodedFrame(final int streamId, final FrameType type, final String data) - { - final byte[] bytes = data.getBytes(); - final ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); - final Payload payload = new Payload() - { - public ByteBuffer getData() - { - return byteBuffer; - } - - public ByteBuffer getMetadata() - { - return Frame.NULL_BYTEBUFFER; - } - }; - - return Frame.Response.from(streamId, type, payload); - } - - /** - * Test encoding of "hello" frames/second with a new string->byte encoding each time - * - * @param input - * @return - * @throws InterruptedException - */ - @Benchmark - public Frame encodeNextCompleteHello(Input input) throws InterruptedException { - return utf8EncodedFrame(0, FrameType.NEXT_COMPLETE, "hello"); - } - - /** - * Test encoding of Frame without any overhead with byte[] or ByteBuffer by reusing the same ByteBuffer - * - * @param input - * @return - */ - @Benchmark - public Frame encodeStaticHelloIntoFrame(Input input) { - input.HELLO.position(0); - return Frame.Response.from(0, FrameType.NEXT_COMPLETE, input.HELLOpayload); - } - - @State(Scope.Thread) - public static class Input { - /** - * Use to consume values when the test needs to return more than a single value. - */ - public Blackhole bh; - - public ByteBuffer HELLO = ByteBuffer.wrap("HELLO".getBytes()); - public Payload HELLOpayload = new Payload() - { - public ByteBuffer getData() - { - return HELLO; - } - - public ByteBuffer getMetadata() - { - return Frame.NULL_BYTEBUFFER; - } - }; - - @Setup - public void setup(Blackhole bh) { - this.bh = bh; - } - } - -} diff --git a/src/perf/java/io/reactivesocket/README.md b/src/perf/java/io/reactivesocket/README.md deleted file mode 100644 index ed7926d78..000000000 --- a/src/perf/java/io/reactivesocket/README.md +++ /dev/null @@ -1,54 +0,0 @@ -# JMH Benchmarks - -### Run All - -``` -./gradlew benchmarks -``` - -### Run Specific Class - -``` -./gradlew benchmarks '-Pjmh=.*FramePerf.*' -``` - -### Arguments - -Optionally pass arguments for custom execution. Example: - -``` -./gradlew benchmarks '-Pjmh=-f 1 -tu s -bm thrpt -wi 5 -i 5 -r 1 .*FramePerf.*' -``` - -gives output like this: - -``` -# Warmup Iteration 1: 12699094.396 ops/s -# Warmup Iteration 2: 15101768.843 ops/s -# Warmup Iteration 3: 14991750.686 ops/s -# Warmup Iteration 4: 14819319.785 ops/s -# Warmup Iteration 5: 14856301.193 ops/s -Iteration 1: 14910334.272 ops/s -Iteration 2: 14954589.540 ops/s -Iteration 3: 15076277.267 ops/s -Iteration 4: 14833413.303 ops/s -Iteration 5: 14893188.328 ops/s - - -Result "encodeNextCompleteHello": - 14933560.542 ±(99.9%) 349800.467 ops/s [Average] - (min, avg, max) = (14833413.303, 14933560.542, 15076277.267), stdev = 90842.071 - CI (99.9%): [14583760.075, 15283361.009] (assumes normal distribution) - - -# Run complete. Total time: 00:00:10 - -Benchmark Mode Cnt Score Error Units -FramePerf.encodeNextCompleteHello thrpt 5 14933560.542 ± 349800.467 ops/s -``` - -To see all options: - -``` -./gradlew benchmarks '-Pjmh=-h' -``` diff --git a/src/perf/java/io/reactivesocket/ReactiveSocketPerf.java b/src/perf/java/io/reactivesocket/ReactiveSocketPerf.java deleted file mode 100644 index 9378cf9d9..000000000 --- a/src/perf/java/io/reactivesocket/ReactiveSocketPerf.java +++ /dev/null @@ -1,289 +0,0 @@ -package io.reactivesocket; - -import io.reactivesocket.internal.PublisherUtils; -import io.reactivesocket.perfutil.PerfTestConnection; -import io.reactivesocket.rx.Completable; -import org.openjdk.jmh.annotations.Benchmark; -import org.openjdk.jmh.annotations.BenchmarkMode; -import org.openjdk.jmh.annotations.Mode; -import org.openjdk.jmh.annotations.OutputTimeUnit; -import org.openjdk.jmh.annotations.Scope; -import org.openjdk.jmh.annotations.Setup; -import org.openjdk.jmh.annotations.State; -import org.openjdk.jmh.infra.Blackhole; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import java.nio.ByteBuffer; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -@BenchmarkMode(Mode.Throughput) -@OutputTimeUnit(TimeUnit.SECONDS) -public class ReactiveSocketPerf { - - @Benchmark - public void requestResponseHello(Input input) { - // this is synchronous so we don't need to use a CountdownLatch to wait - Input.client.requestResponse(Input.HELLO_PAYLOAD).subscribe(input.blackholeConsumer); - } - - @Benchmark - public void requestStreamHello1000(Input input) { - // this is synchronous so we don't need to use a CountdownLatch to wait - Input.client.requestStream(Input.HELLO_PAYLOAD).subscribe(input.blackholeConsumer); - } - - @Benchmark - public void fireAndForgetHello(Input input) { - // this is synchronous so we don't need to use a CountdownLatch to wait - Input.client.fireAndForget(Input.HELLO_PAYLOAD).subscribe(input.voidBlackholeConsumer); - } - - @State(Scope.Thread) - public static class Input { - /** - * Use to consume values when the test needs to return more than a single value. - */ - public Blackhole bh; - - static final ByteBuffer HELLO = ByteBuffer.wrap("HELLO".getBytes()); - static final ByteBuffer HELLO_WORLD = ByteBuffer.wrap("HELLO_WORLD".getBytes()); - static final ByteBuffer EMPTY = ByteBuffer.allocate(0); - - static final Payload HELLO_PAYLOAD = new Payload() { - - @Override - public ByteBuffer getMetadata() { - return EMPTY; - } - - @Override - public ByteBuffer getData() { - HELLO.position(0); - return HELLO; - } - }; - - static final Payload HELLO_WORLD_PAYLOAD = new Payload() { - - @Override - public ByteBuffer getMetadata() { - return EMPTY; - } - - @Override - public ByteBuffer getData() { - HELLO_WORLD.position(0); - return HELLO_WORLD; - } - }; - - final static PerfTestConnection serverConnection = new PerfTestConnection(); - final static PerfTestConnection clientConnection = new PerfTestConnection(); - - static { - clientConnection.connectToServerConnection(serverConnection); - } - - private static Publisher HELLO_1 = just(HELLO_WORLD_PAYLOAD); - private static Publisher HELLO_1000; - - static { - Payload[] ps = new Payload[1000]; - for (int i = 0; i < ps.length; i++) { - ps[i] = HELLO_WORLD_PAYLOAD; - } - HELLO_1000 = just(ps); - } - - static final RequestHandler handler = new RequestHandler() { - - @Override - public Publisher handleRequestResponse(Payload payload) { - return HELLO_1; - } - - @Override - public Publisher handleRequestStream(Payload payload) { - return HELLO_1000; - } - - @Override - public Publisher handleSubscription(Payload payload) { - return null; - } - - @Override - public Publisher handleFireAndForget(Payload payload) { - return PublisherUtils.empty(); - } - - @Override - public Publisher handleChannel(Payload initialPayload, Publisher inputs) { - return null; - } - - @Override - public Publisher handleMetadataPush(Payload payload) - { - return null; - } - }; - - final static ReactiveSocket serverSocket = DefaultReactiveSocket.fromServerConnection(serverConnection, (setup, rs) -> handler); - - final static ReactiveSocket client = - DefaultReactiveSocket.fromClientConnection( - clientConnection, ConnectionSetupPayload.create("UTF-8", "UTF-8", ConnectionSetupPayload.NO_FLAGS), t -> {}); - - static { - LatchedCompletable lc = new LatchedCompletable(2); - serverSocket.start(lc); - client.start(lc); - try { - lc.latch.await(); - } catch (InterruptedException e) { - throw new RuntimeException("Failed waiting on startup", e); - } - } - - Subscriber blackholeConsumer; // reuse this each time - Subscriber voidBlackholeConsumer; // reuse this each time - - @Setup - public void setup(Blackhole bh) { - this.bh = bh; - blackholeConsumer = new Subscriber() { - - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Payload t) { - bh.consume(t); - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - } - - @Override - public void onComplete() { - - } - - }; - - voidBlackholeConsumer = new Subscriber() { - - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Void t) { - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - } - - @Override - public void onComplete() { - - } - - }; - } - } - - private static Publisher just(Payload... ps) { - return new Publisher() { - - @Override - public void subscribe(Subscriber s) { - s.onSubscribe(new Subscription() { - - int emitted = 0; - - @Override - public void request(long n) { - // NOTE: This is not a safe implementation as it assumes synchronous request(n) - if (emitted == ps.length) { - s.onComplete(); - return; - } - long _n = Math.min(n, ps.length); - for (int i = 0; i < _n; i++) { - s.onNext(ps[emitted++]); - if (emitted == ps.length) { - s.onComplete(); - break; - } - } - } - - @Override - public void cancel() { - - } - - }); - } - - }; - } - - private static class ErrorSubscriber implements Subscriber { - - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(T t) { - - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - } - - @Override - public void onComplete() { - - } - - } - - private static class LatchedCompletable implements Completable { - - final CountDownLatch latch; - - LatchedCompletable(int count) { - this.latch = new CountDownLatch(count); - } - - @Override - public void success() { - latch.countDown(); - } - - @Override - public void error(Throwable e) { - System.err.println("Error waiting for Requester"); - e.printStackTrace(); - latch.countDown(); - } - - }; -} diff --git a/src/perf/java/io/reactivesocket/perfutil/PerfTestConnection.java b/src/perf/java/io/reactivesocket/perfutil/PerfTestConnection.java deleted file mode 100644 index 1d05f2eec..000000000 --- a/src/perf/java/io/reactivesocket/perfutil/PerfTestConnection.java +++ /dev/null @@ -1,82 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.perfutil; - -import java.io.IOException; - -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import io.reactivesocket.DuplexConnection; -import io.reactivesocket.Frame; -import io.reactivesocket.rx.Completable; -import io.reactivesocket.rx.Observable; - -public class PerfTestConnection implements DuplexConnection { - - public final PerfUnicastSubjectNoBackpressure toInput = PerfUnicastSubjectNoBackpressure.create(); - private PerfUnicastSubjectNoBackpressure writeSubject = PerfUnicastSubjectNoBackpressure.create(); - - @Override - public void addOutput(Publisher o, Completable callback) { - o.subscribe(new Subscriber() { - - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Frame f) { - writeSubject.onNext(f); - } - - @Override - public void onError(Throwable t) { - callback.error(t); - } - - @Override - public void onComplete() { - callback.success(); - } - - }); - } - - @Override - public void addOutput(Frame f, Completable callback) { - writeSubject.onNext(f); - callback.success(); - } - - @Override - public Observable getInput() { - return toInput; - } - - public void connectToServerConnection(PerfTestConnection serverConnection) { - writeSubject.subscribe(serverConnection.toInput); - serverConnection.writeSubject.subscribe(toInput); - - } - - @Override - public void close() throws IOException { - - } -} \ No newline at end of file diff --git a/src/perf/java/io/reactivesocket/perfutil/PerfUnicastSubjectNoBackpressure.java b/src/perf/java/io/reactivesocket/perfutil/PerfUnicastSubjectNoBackpressure.java deleted file mode 100644 index 2f1a5945d..000000000 --- a/src/perf/java/io/reactivesocket/perfutil/PerfUnicastSubjectNoBackpressure.java +++ /dev/null @@ -1,92 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.perfutil; - -import java.util.function.Consumer; - -import io.reactivesocket.rx.Disposable; -import io.reactivesocket.rx.Observable; -import io.reactivesocket.rx.Observer; - -/** - * The difference between this and the real UnicastSubject is in the `onSubscribe` method where it calls requestN. Not sure that behavior should exist in the producton code. - */ -public final class PerfUnicastSubjectNoBackpressure implements Observable, Observer { - - private Observer s; - private final Consumer> onConnect; - private boolean subscribedTo = false; - - public static PerfUnicastSubjectNoBackpressure create() { - return new PerfUnicastSubjectNoBackpressure<>(null); - } - - /** - * @param onConnect Called when first requestN > 0 occurs. - * @return - */ - public static PerfUnicastSubjectNoBackpressure create(Consumer> onConnect) { - return new PerfUnicastSubjectNoBackpressure<>(onConnect); - } - - private PerfUnicastSubjectNoBackpressure(Consumer> onConnect) { - this.onConnect = onConnect; - } - - @Override - public void onSubscribe(Disposable s) { - } - - @Override - public void onNext(T t) { - s.onNext(t); - } - - @Override - public void onError(Throwable t) { - s.onError(t); - } - - @Override - public void onComplete() { - s.onComplete(); - } - - @Override - public void subscribe(Observer s) { - if (this.s != null) { - s.onError(new IllegalStateException("Only single Subscriber supported")); - } else { - this.s = s; - this.s.onSubscribe(new Disposable() { - - @Override - public void dispose() { - // transport has shut us down - } - - }); - if(onConnect != null) { - onConnect.accept(PerfUnicastSubjectNoBackpressure.this); - } - } - } - - public boolean isSubscribedTo() { - return subscribedTo; - } - -} diff --git a/src/test/java/io/reactivesocket/FrameTest.java b/src/test/java/io/reactivesocket/FrameTest.java deleted file mode 100644 index 03efb86eb..000000000 --- a/src/test/java/io/reactivesocket/FrameTest.java +++ /dev/null @@ -1,535 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import static org.junit.Assert.*; - -import java.nio.ByteBuffer; -import java.util.concurrent.TimeUnit; - -import io.reactivesocket.exceptions.RejectedException; -import io.reactivesocket.internal.frame.SetupFrameFlyweight; - -import org.junit.Test; -import org.junit.experimental.theories.DataPoint; -import org.junit.experimental.theories.Theories; -import org.junit.experimental.theories.Theory; -import org.junit.runner.RunWith; -import org.agrona.concurrent.UnsafeBuffer; - -import static io.reactivesocket.internal.frame.ErrorFrameFlyweight.*; -import static java.nio.charset.StandardCharsets.UTF_8; - -@RunWith(Theories.class) -public class FrameTest -{ - private static Payload createPayload(final ByteBuffer metadata, final ByteBuffer data) - { - return new Payload() - { - public ByteBuffer getData() - { - return data; - } - - public ByteBuffer getMetadata() - { - return metadata; - } - }; - } - - @DataPoint - public static final int ZERO_OFFSET = 0; - - @DataPoint - public static final int NON_ZERO_OFFSET = 127; - - private static final UnsafeBuffer reusableMutableDirectBuffer = new UnsafeBuffer(ByteBuffer.allocate(1024)); - private static final Frame reusableFrame = Frame.allocate(reusableMutableDirectBuffer); - - @Test - public void testWriteThenRead() { - final ByteBuffer helloBuffer = TestUtil.byteBufferFromUtf8String("hello"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, helloBuffer); - - Frame f = Frame.Request.from(1, FrameType.REQUEST_RESPONSE, payload, 1); - - assertEquals("hello", TestUtil.byteToString(f.getData())); - assertEquals(FrameType.REQUEST_RESPONSE, f.getType()); - assertEquals(1, f.getStreamId()); - - ByteBuffer b = f.getByteBuffer(); - - Frame f2 = Frame.from(b); - assertEquals("hello", TestUtil.byteToString(f2.getData())); - assertEquals(FrameType.REQUEST_RESPONSE, f2.getType()); - assertEquals(1, f2.getStreamId()); - } - - @Test - public void testWrapMessage() { - final ByteBuffer helloBuffer = TestUtil.byteBufferFromUtf8String("hello"); - final ByteBuffer doneBuffer = TestUtil.byteBufferFromUtf8String("done"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, helloBuffer); - - Frame f = Frame.Request.from(1, FrameType.REQUEST_RESPONSE, payload, 1); - - f.wrap(2, FrameType.COMPLETE, doneBuffer); - assertEquals("done", TestUtil.byteToString(f.getData())); - assertEquals(FrameType.NEXT_COMPLETE, f.getType()); - assertEquals(2, f.getStreamId()); - } - - @Test - public void testWrapBytes() { - final ByteBuffer helloBuffer = TestUtil.byteBufferFromUtf8String("hello"); - final ByteBuffer anotherBuffer = TestUtil.byteBufferFromUtf8String("another"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, helloBuffer); - final Payload anotherPayload = createPayload(Frame.NULL_BYTEBUFFER, anotherBuffer); - - Frame f = Frame.Request.from(1, FrameType.REQUEST_RESPONSE, payload, 1); - Frame f2 = Frame.Response.from(20, FrameType.COMPLETE, anotherPayload); - - ByteBuffer b = f2.getByteBuffer(); - f.wrap(b, 0); - - assertEquals("another", TestUtil.byteToString(f.getData())); - assertEquals(FrameType.NEXT_COMPLETE, f.getType()); - assertEquals(20, f.getStreamId()); - } - - @Test - @Theory - public void shouldReturnCorrectDataPlusMetadataForRequestResponse(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final ByteBuffer requestMetadata = TestUtil.byteBufferFromUtf8String("request metadata"); - final Payload payload = createPayload(requestMetadata, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.REQUEST_RESPONSE, payload, 1); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(FrameType.REQUEST_RESPONSE, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - assertEquals("request metadata", TestUtil.byteToString(reusableFrame.getMetadata())); - } - - @Test - @Theory - public void shouldReturnCorrectDataPlusMetadataForFireAndForget(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final ByteBuffer requestMetadata = TestUtil.byteBufferFromUtf8String("request metadata"); - final Payload payload = createPayload(requestMetadata, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.FIRE_AND_FORGET, payload, 0); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - assertEquals("request metadata", TestUtil.byteToString(reusableFrame.getMetadata())); - assertEquals(FrameType.FIRE_AND_FORGET, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - } - - @Test - @Theory - public void shouldReturnCorrectDataPlusMetadataForRequestStream(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final ByteBuffer requestMetadata = TestUtil.byteBufferFromUtf8String("request metadata"); - final Payload payload = createPayload(requestMetadata, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.REQUEST_STREAM, payload, 128); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - assertEquals("request metadata", TestUtil.byteToString(reusableFrame.getMetadata())); - assertEquals(FrameType.REQUEST_STREAM, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - assertEquals(128, Frame.Request.initialRequestN(reusableFrame)); - } - - @Test - @Theory - public void shouldReturnCorrectDataPlusMetadataForRequestSubscription(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final ByteBuffer requestMetadata = TestUtil.byteBufferFromUtf8String("request metadata"); - final Payload payload = createPayload(requestMetadata, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.REQUEST_SUBSCRIPTION, payload, 128); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - assertEquals("request metadata", TestUtil.byteToString(reusableFrame.getMetadata())); - assertEquals(FrameType.REQUEST_SUBSCRIPTION, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - assertEquals(128, Frame.Request.initialRequestN(reusableFrame)); - } - - @Test - @Theory - public void shouldReturnCorrectDataPlusMetadataForResponse(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("response data"); - final ByteBuffer requestMetadata = TestUtil.byteBufferFromUtf8String("response metadata"); - final Payload payload = createPayload(requestMetadata, requestData); - - Frame encodedFrame = Frame.Response.from(1, FrameType.RESPONSE, payload); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("response data", TestUtil.byteToString(reusableFrame.getData())); - assertEquals("response metadata", TestUtil.byteToString(reusableFrame.getMetadata())); - assertEquals(FrameType.NEXT, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithoutMetadataForRequestResponse(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.REQUEST_RESPONSE, payload, 1); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - - final ByteBuffer metadataBuffer = reusableFrame.getMetadata(); - assertEquals(0, metadataBuffer.capacity()); - assertEquals(FrameType.REQUEST_RESPONSE, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithoutMetadataForFireAndForget(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.FIRE_AND_FORGET, payload, 0); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - - final ByteBuffer metadataBuffer = reusableFrame.getMetadata(); - assertEquals(0, metadataBuffer.capacity()); - assertEquals(FrameType.FIRE_AND_FORGET, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithoutMetadataForRequestStream(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.REQUEST_STREAM, payload, 128); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - - final ByteBuffer metadataBuffer = reusableFrame.getMetadata(); - assertEquals(0, metadataBuffer.capacity()); - assertEquals(FrameType.REQUEST_STREAM, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - assertEquals(128, Frame.Request.initialRequestN(reusableFrame)); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithoutMetadataForRequestSubscription(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("request data"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, requestData); - - Frame encodedFrame = Frame.Request.from(1, FrameType.REQUEST_SUBSCRIPTION, payload, 128); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("request data", TestUtil.byteToString(reusableFrame.getData())); - - final ByteBuffer metadataBuffer = reusableFrame.getMetadata(); - assertEquals(0, metadataBuffer.capacity()); - assertEquals(FrameType.REQUEST_SUBSCRIPTION, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - assertEquals(128, Frame.Request.initialRequestN(reusableFrame)); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithoutMetadataForResponse(final int offset) - { - final ByteBuffer requestData = TestUtil.byteBufferFromUtf8String("response data"); - final Payload payload = createPayload(Frame.NULL_BYTEBUFFER, requestData); - - Frame encodedFrame = Frame.Response.from(1, FrameType.RESPONSE, payload); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals("response data", TestUtil.byteToString(reusableFrame.getData())); - - final ByteBuffer metadataBuffer = reusableFrame.getMetadata(); - assertEquals(0, metadataBuffer.capacity()); - assertEquals(FrameType.NEXT, reusableFrame.getType()); - assertEquals(1, reusableFrame.getStreamId()); - } - - @Test - @Theory - public void shouldReturnCorrectDataPlusMetadataForSetup(final int offset) - { - final int flags = SetupFrameFlyweight.FLAGS_WILL_HONOR_LEASE | SetupFrameFlyweight.FLAGS_STRICT_INTERPRETATION; - final int version = SetupFrameFlyweight.CURRENT_VERSION; - final int keepaliveInterval = 1001; - final int maxLifetime = keepaliveInterval * 5; - final String metadataMimeType = "application/json"; - final String dataMimeType = "application/cbor"; - final ByteBuffer setupData = TestUtil.byteBufferFromUtf8String("setup data"); - final ByteBuffer setupMetadata = TestUtil.byteBufferFromUtf8String("setup metadata"); - - Frame encodedFrame = Frame.Setup.from(flags, keepaliveInterval, maxLifetime, metadataMimeType, dataMimeType, new Payload() - { - public ByteBuffer getData() - { - return setupData; - } - - public ByteBuffer getMetadata() - { - return setupMetadata; - } - }); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(FrameType.SETUP, reusableFrame.getType()); - assertEquals(flags, Frame.Setup.getFlags(reusableFrame)); - assertEquals(version, Frame.Setup.version(reusableFrame)); - assertEquals(keepaliveInterval, Frame.Setup.keepaliveInterval(reusableFrame)); - assertEquals(maxLifetime, Frame.Setup.maxLifetime(reusableFrame)); - assertEquals(metadataMimeType, Frame.Setup.metadataMimeType(reusableFrame)); - assertEquals(dataMimeType, Frame.Setup.dataMimeType(reusableFrame)); - assertEquals("setup data", TestUtil.byteToString(reusableFrame.getData())); - assertEquals("setup metadata", TestUtil.byteToString(reusableFrame.getMetadata())); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithoutMetadataForSetup(final int offset) - { - final int flags = SetupFrameFlyweight.FLAGS_WILL_HONOR_LEASE | SetupFrameFlyweight.FLAGS_STRICT_INTERPRETATION; - final int version = SetupFrameFlyweight.CURRENT_VERSION; - final int keepaliveInterval = 1001; - final int maxLifetime = keepaliveInterval * 5; - final String metadataMimeType = "application/json"; - final String dataMimeType = "application/cbor"; - final ByteBuffer setupData = TestUtil.byteBufferFromUtf8String("setup data"); - - Frame encodedFrame = Frame.Setup.from(flags, keepaliveInterval, maxLifetime, metadataMimeType, dataMimeType, new Payload() - { - public ByteBuffer getData() - { - return setupData; - } - - public ByteBuffer getMetadata() - { - return Frame.NULL_BYTEBUFFER; - } - }); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(FrameType.SETUP, reusableFrame.getType()); - assertEquals(flags, Frame.Setup.getFlags(reusableFrame)); - assertEquals(version, Frame.Setup.version(reusableFrame)); - assertEquals(keepaliveInterval, Frame.Setup.keepaliveInterval(reusableFrame)); - assertEquals(maxLifetime, Frame.Setup.maxLifetime(reusableFrame)); - assertEquals(metadataMimeType, Frame.Setup.metadataMimeType(reusableFrame)); - assertEquals(dataMimeType, Frame.Setup.dataMimeType(reusableFrame)); - assertEquals("setup data", TestUtil.byteToString(reusableFrame.getData())); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getMetadata()); - } - - @Test - @Theory - public void shouldFormCorrectlyWithoutDataNorMetadataForSetup(final int offset) - { - final int flags = SetupFrameFlyweight.FLAGS_WILL_HONOR_LEASE | SetupFrameFlyweight.FLAGS_STRICT_INTERPRETATION; - final int version = SetupFrameFlyweight.CURRENT_VERSION; - final int keepaliveInterval = 1001; - final int maxLifetime = keepaliveInterval * 5; - final String metadataMimeType = "application/json"; - final String dataMimeType = "application/cbor"; - - Frame encodedFrame = Frame.Setup.from(flags, keepaliveInterval, maxLifetime, metadataMimeType, dataMimeType, new Payload() - { - public ByteBuffer getData() - { - return Frame.NULL_BYTEBUFFER; - } - - public ByteBuffer getMetadata() - { - return Frame.NULL_BYTEBUFFER; - } - }); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(FrameType.SETUP, reusableFrame.getType()); - assertEquals(flags, Frame.Setup.getFlags(reusableFrame)); - assertEquals(version, Frame.Setup.version(reusableFrame)); - assertEquals(keepaliveInterval, Frame.Setup.keepaliveInterval(reusableFrame)); - assertEquals(maxLifetime, Frame.Setup.maxLifetime(reusableFrame)); - assertEquals(metadataMimeType, Frame.Setup.metadataMimeType(reusableFrame)); - assertEquals(dataMimeType, Frame.Setup.dataMimeType(reusableFrame)); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getData()); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getMetadata()); - } - - @Test - @Theory - public void shouldReturnCorrectDataPlusMetadataForError(final int offset) - { - final int streamId = 24; - final Throwable exception = new RejectedException("test"); - final String data = "error data"; - final String metadata = "error metadata"; - final ByteBuffer dataByteBuffer = ByteBuffer.wrap(data.getBytes(UTF_8)); - final ByteBuffer metadataByteBuffer = ByteBuffer.wrap(metadata.getBytes(UTF_8)); - - Frame encodedFrame = Frame.Error.from(streamId, exception, metadataByteBuffer, dataByteBuffer); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(FrameType.ERROR, reusableFrame.getType()); - assertEquals(REJECTED, Frame.Error.errorCode(reusableFrame)); - assertEquals(data, TestUtil.byteToString(reusableFrame.getData())); - assertEquals(metadata, TestUtil.byteToString(reusableFrame.getMetadata())); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithThrowableForError(final int offset) - { - final int errorCode = 42; - final String metadata = "my metadata"; - final String exMessage = "exception message"; - - Frame encodedFrame = Frame.Error.from( - errorCode, - new Exception(exMessage), - TestUtil.byteBufferFromUtf8String(metadata) - ); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - - assertEquals(FrameType.ERROR, reusableFrame.getType()); - assertEquals(exMessage, TestUtil.byteToString(reusableFrame.getData())); - assertEquals(TestUtil.byteBufferFromUtf8String(metadata), reusableFrame.getMetadata()); - } - - @Test - @Theory - public void shouldReturnCorrectDataWithoutMetadataForError(final int offset) - { - final int errorCode = 42; - final String metadata = "metadata"; - final String data = "error data"; - - Frame encodedFrame = Frame.Error.from( - errorCode, - new Exception("my exception"), - TestUtil.byteBufferFromUtf8String(metadata), - TestUtil.byteBufferFromUtf8String(data) - ); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(FrameType.ERROR, reusableFrame.getType()); - assertEquals(data, TestUtil.byteToString(reusableFrame.getData())); - assertEquals(metadata, TestUtil.byteToString(reusableFrame.getMetadata())); - } - - @Test - @Theory - public void shouldFormCorrectlyForRequestN(final int offset) - { - final int n = 128; - final Frame encodedFrame = Frame.RequestN.from(1, n); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(FrameType.REQUEST_N, reusableFrame.getType()); - assertEquals(n, Frame.RequestN.requestN(reusableFrame)); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getData()); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getMetadata()); - } - - @Test - @Theory - public void shouldFormCorrectlyWithoutMetadataForLease(final int offset) - { - final int ttl = (int)TimeUnit.SECONDS.toMillis(8); - final int numberOfRequests = 16; - final Frame encodedFrame = Frame.Lease.from(ttl, numberOfRequests, Frame.NULL_BYTEBUFFER); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(0, reusableFrame.getStreamId()); - assertEquals(FrameType.LEASE, reusableFrame.getType()); - assertEquals(ttl, Frame.Lease.ttl(reusableFrame)); - assertEquals(numberOfRequests, Frame.Lease.numberOfRequests(reusableFrame)); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getData()); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getMetadata()); - } - - @Test - @Theory - public void shouldFormCorrectlyWithMetadataForLease(final int offset) - { - final int ttl = (int)TimeUnit.SECONDS.toMillis(8); - final int numberOfRequests = 16; - final ByteBuffer leaseMetadata = TestUtil.byteBufferFromUtf8String("lease metadata"); - - final Frame encodedFrame = Frame.Lease.from(ttl, numberOfRequests, leaseMetadata); - TestUtil.copyFrame(reusableMutableDirectBuffer, offset, encodedFrame); - reusableFrame.wrap(reusableMutableDirectBuffer, offset); - - assertEquals(0, reusableFrame.getStreamId()); - assertEquals(FrameType.LEASE, reusableFrame.getType()); - assertEquals(ttl, Frame.Lease.ttl(reusableFrame)); - assertEquals(numberOfRequests, Frame.Lease.numberOfRequests(reusableFrame)); - assertEquals(Frame.NULL_BYTEBUFFER, reusableFrame.getData()); - assertEquals("lease metadata", TestUtil.byteToString(reusableFrame.getMetadata())); - } -} diff --git a/src/test/java/io/reactivesocket/LatchedCompletable.java b/src/test/java/io/reactivesocket/LatchedCompletable.java deleted file mode 100644 index e70df1df4..000000000 --- a/src/test/java/io/reactivesocket/LatchedCompletable.java +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import io.reactivesocket.rx.Completable; - -public class LatchedCompletable implements Completable { - - final CountDownLatch latch; - - public LatchedCompletable(int count) { - this.latch = new CountDownLatch(count); - } - - @Override - public void success() { - latch.countDown(); - } - - @Override - public void error(Throwable e) { - System.err.println("Error waiting for Requester"); - e.printStackTrace(); - latch.countDown(); - } - - public void await() throws InterruptedException { - latch.await(); - } - - public boolean await(long timeout, TimeUnit unit) throws InterruptedException { - return latch.await(timeout, unit); - } - - -} diff --git a/src/test/java/io/reactivesocket/LeaseTest.java b/src/test/java/io/reactivesocket/LeaseTest.java deleted file mode 100644 index 65c36738e..000000000 --- a/src/test/java/io/reactivesocket/LeaseTest.java +++ /dev/null @@ -1,221 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.internal.Responder; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.reactivestreams.Publisher; -import io.reactivex.subscribers.TestSubscriber; - -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; - -import static io.reactivesocket.TestUtil.byteToString; -import static io.reactivesocket.TestUtil.utf8EncodedPayload; -import static io.reactivesocket.ConnectionSetupPayload.HONOR_LEASE; - -import static org.junit.Assert.assertTrue; -import static io.reactivex.Observable.*; - -public class LeaseTest { - private TestConnection clientConnection; - private ReactiveSocket socketServer; - private ReactiveSocket socketClient; - private TestingLeaseGovernor leaseGovernor; - - private class TestingLeaseGovernor implements LeaseGovernor { - private volatile Responder responder; - private volatile long ttlExpiration; - private volatile int grantedTickets; - private CountDownLatch latch = new CountDownLatch(1); - - @Override - public synchronized void register(Responder responder) { - this.responder = responder; - latch.countDown(); - } - - @Override - public synchronized void unregister(Responder responder) { - this.responder = null; - } - - @Override - public synchronized boolean accept(Responder responder, Frame frame) { - boolean valid = grantedTickets > 0 - && ttlExpiration >= System.currentTimeMillis(); - grantedTickets--; - return valid; - } - - public synchronized void distribute(int ttlMs, int tickets) { - if (responder == null) { - throw new IllegalStateException("responder is null"); - } - ttlExpiration = System.currentTimeMillis() + ttlMs; - grantedTickets = tickets; - responder.sendLease(ttlMs, tickets); - } - } - - @Before - public void setup() throws InterruptedException { - TestConnection serverConnection = new TestConnection(); - clientConnection = new TestConnection(); - clientConnection.connectToServerConnection(serverConnection); - leaseGovernor = new TestingLeaseGovernor(); - - socketServer = DefaultReactiveSocket.fromServerConnection( - serverConnection, (setup, rs) -> new RequestHandler() { - - @Override - public Publisher handleRequestResponse(Payload payload) { - return just(utf8EncodedPayload("hello world", null)); - } - - @Override - public Publisher handleRequestStream(Payload payload) { - return - range(0, 100) - .map(i -> "hello world " + i) - .map(n -> utf8EncodedPayload(n, null) - ); - } - - @Override - public Publisher handleSubscription(Payload payload) { - return interval(1, TimeUnit.MICROSECONDS) - .map(i -> "subscription " + i) - .map(n -> utf8EncodedPayload(n, null)); - } - - @Override - public Publisher handleFireAndForget(Payload payload) { - return empty(); - } - - /** - * Use Payload.metadata for routing - */ - @Override - public Publisher handleChannel( - Payload initialPayload, Publisher inputs - ) { - return fromPublisher(inputs).map(p -> - utf8EncodedPayload(byteToString(p.getData()) + "_echo", null)); - } - - @Override - public Publisher handleMetadataPush(Payload payload) { - throw new IllegalStateException( - "TestingLeaseGovernor.handleMetadataPush is not implemented!"); - } - }, leaseGovernor, t -> {}); - - socketClient = DefaultReactiveSocket.fromClientConnection( - clientConnection, - ConnectionSetupPayload.create("UTF-8", "UTF-8", HONOR_LEASE) - ); - - // start both the server and client and monitor for errors - LatchedCompletable lc = new LatchedCompletable(2); - socketServer.start(lc); - socketClient.start(lc); - if(!lc.await(3000, TimeUnit.MILLISECONDS)) { - throw new RuntimeException("Timed out waiting for startup"); - } - } - - @After - public void shutdown() { - socketServer.shutdown(); - socketClient.shutdown(); - } - - @Test(timeout=2000) - public void testWriteWithoutLease() throws InterruptedException { - // initially client doesn't have any availability - assertTrue(socketClient.availability() == 0.0); - leaseGovernor.latch.await(); - assertTrue(socketClient.availability() == 0.0); - - // the first call will fail without a valid lease - Publisher response0 = socketClient.requestResponse( - TestUtil.utf8EncodedPayload("hello", null)); - TestSubscriber ts0 = new TestSubscriber<>();; - response0.subscribe(ts0); - ts0.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - - // send a Lease(10 sec, 1 message), and wait for the availability on the client side - leaseGovernor.distribute(10_000, 1); - awaitSocketAvailabilityChange(socketClient, 1.0, 10, TimeUnit.SECONDS); - - // the second call will succeed - Publisher response1 = socketClient.requestResponse( - TestUtil.utf8EncodedPayload("hello", null)); - TestSubscriber ts1 = new TestSubscriber<>();; - response1.subscribe(ts1); - ts1.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - ts1.assertNoErrors(); - ts1.assertValue(TestUtil.utf8EncodedPayload("hello world", null)); - - // the client consumed all its ticket, next call will fail - // (even though the window is still ok) - Publisher response2 = socketClient.requestResponse( - TestUtil.utf8EncodedPayload("hello", null)); - TestSubscriber ts2 = new TestSubscriber<>(); - response2.subscribe(ts2); - ts2.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - ts2.assertError(RuntimeException.class); - } - - @Test(timeout=2000) - public void testLeaseOverwrite() throws InterruptedException { - - assertTrue(socketClient.availability() == 0.0); - leaseGovernor.latch.await(); - assertTrue(socketClient.availability() == 0.0); - - leaseGovernor.distribute(10_000, 100); - awaitSocketAvailabilityChange(socketClient, 1.0, 10, TimeUnit.SECONDS); - - leaseGovernor.distribute(10_000, 0); - awaitSocketAvailabilityChange(socketClient, 0.0, 10, TimeUnit.SECONDS); - } - - private void awaitSocketAvailabilityChange( - ReactiveSocket socket, - double expected, - long timeout, - TimeUnit unit - ) throws InterruptedException { - long waitTimeMs = 1L; - long startTime = System.nanoTime(); - long timeoutNanos = TimeUnit.NANOSECONDS.convert(timeout, unit); - - while (socket.availability() != expected) { - Thread.sleep(waitTimeMs); - waitTimeMs = Math.min(waitTimeMs * 2, 1000L); - final long elapsedNanos = System.nanoTime() - startTime; - if (elapsedNanos > timeoutNanos) { - throw new IllegalStateException("Timeout while waiting for socket availability"); - } - } - } - -} diff --git a/src/test/java/io/reactivesocket/ReactiveSocketTest.java b/src/test/java/io/reactivesocket/ReactiveSocketTest.java deleted file mode 100644 index 28229a777..000000000 --- a/src/test/java/io/reactivesocket/ReactiveSocketTest.java +++ /dev/null @@ -1,514 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.lease.FairLeaseGovernor; -import io.reactivex.disposables.Disposable; -import io.reactivex.observables.ConnectableObservable; -import io.reactivex.subscribers.TestSubscriber; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; -import org.junit.experimental.theories.DataPoints; -import org.junit.experimental.theories.Theories; -import org.junit.experimental.theories.Theory; -import org.junit.runner.RunWith; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; - -import static io.reactivesocket.ConnectionSetupPayload.HONOR_LEASE; -import static io.reactivesocket.ConnectionSetupPayload.NO_FLAGS; -import static io.reactivesocket.TestUtil.byteToString; -import static io.reactivesocket.TestUtil.utf8EncodedPayload; -import static io.reactivex.Observable.empty; -import static io.reactivex.Observable.error; -import static io.reactivex.Observable.fromPublisher; -import static io.reactivex.Observable.interval; -import static io.reactivex.Observable.just; -import static io.reactivex.Observable.range; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -@RunWith(Theories.class) -public class ReactiveSocketTest { - - private TestConnection clientConnection; - private ReactiveSocket socketServer; - private ReactiveSocket socketClient; - private AtomicBoolean helloSubscriptionRunning = new AtomicBoolean(false); - private AtomicReference lastFireAndForget = new AtomicReference<>(); - private AtomicReference lastMetadataPush = new AtomicReference<>(); - private AtomicReference lastServerError = new AtomicReference<>(); - private CountDownLatch lastServerErrorCountDown; - private CountDownLatch fireAndForgetOrMetadataPush; - - public static final @DataPoints int[] setupFlags = {NO_FLAGS, HONOR_LEASE}; - - @Before - public void setup() { - TestConnection serverConnection = new TestConnection(); - clientConnection = new TestConnection(); - clientConnection.connectToServerConnection(serverConnection); - fireAndForgetOrMetadataPush = new CountDownLatch(1); - lastServerErrorCountDown = new CountDownLatch(1); - - socketServer = DefaultReactiveSocket.fromServerConnection(serverConnection, (setup,rs) -> new RequestHandler() { - - @Override - public Publisher handleRequestResponse(Payload payload) { - String request = byteToString(payload.getData()); - System.out.println("********************************************************************************************** requestResponse: " + request); - if ("hello".equals(request)) { - System.out.println("********************************************************************************************** respond hello"); - return just(utf8EncodedPayload("hello world", null)); - } else { - return error(new RuntimeException("Not Found")); - } - } - - @Override - public Publisher handleRequestStream(Payload payload) { - String request = byteToString(payload.getData()); - if ("hello".equals(request)) { - return range(0, 100).map(i -> "hello world " + i).map(n -> utf8EncodedPayload(n, null)); - } else { - return error(new RuntimeException("Not Found")); - } - } - - @Override - public Publisher handleSubscription(Payload payload) { - String request = byteToString(payload.getData()); - if ("hello".equals(request)) { - return interval(1, TimeUnit.MICROSECONDS) - .onBackpressureDrop() - .doOnSubscribe(s -> helloSubscriptionRunning.set(true)) - .doOnCancel(() -> helloSubscriptionRunning.set(false)) - .map(i -> "subscription " + i) - .map(n -> utf8EncodedPayload(n, null)); - } else { - return error(new RuntimeException("Not Found")); - } - } - - @Override - public Publisher handleFireAndForget(Payload payload) { - try { - String request = byteToString(payload.getData()); - lastFireAndForget.set(request); - if ("log".equals(request)) { - return empty(); // success - } else if ("blowup".equals(request)) { - throw new RuntimeException("forced blowup to simulate handler error"); - } else { - lastFireAndForget.set("notFound"); - return error(new RuntimeException("Not Found")); - } - } finally { - fireAndForgetOrMetadataPush.countDown(); - } - } - - /** - * Use Payload.metadata for routing - */ - @Override - public Publisher handleChannel(Payload initialPayload, Publisher inputs) { - return new Publisher() { - @Override - public void subscribe(Subscriber subscriber) { - inputs.subscribe(new Subscriber() { - @Override - public void onSubscribe(Subscription s) { - subscriber.onSubscribe(s); - } - - @Override - public void onNext(Payload input) { - String metadata = byteToString(input.getMetadata()); - String data = byteToString(input.getData()); - if ("echo".equals(metadata)) { - subscriber.onNext(utf8EncodedPayload(data + "_echo", null)); - } else { - onError(new RuntimeException("Not Found")); - } - } - - @Override - public void onError(Throwable t) { - subscriber.onError(t); - } - - @Override - public void onComplete() { - subscriber.onComplete(); - } - }); - } - }; - } - - @Override - public Publisher handleMetadataPush(Payload payload) - { - try { - String request = byteToString(payload.getMetadata()); - lastMetadataPush.set(request); - if ("log".equals(request)) { - return empty(); // success - } else if ("blowup".equals(request)) { - throw new RuntimeException("forced blowup to simulate handler error"); - } else { - lastMetadataPush.set("notFound"); - return error(new RuntimeException("Not Found")); - } - } finally { - fireAndForgetOrMetadataPush.countDown(); - } - } - - private Publisher echoChannel(Publisher echo) { - return fromPublisher(echo).map(p -> { - return utf8EncodedPayload(byteToString(p.getData()) + "_echo", null); - }); - } - -// }, LeaseGovernor.UNLIMITED_LEASE_GOVERNOR, t -> { - }, new FairLeaseGovernor(100, 10L, TimeUnit.SECONDS), t -> { - t.printStackTrace(); - lastServerError.set(t); - lastServerErrorCountDown.countDown(); - }); - } - - @After - public void shutdown() { - socketServer.shutdown(); - socketClient.shutdown(); - } - - private void startSockets(int setupFlag, RequestHandler handler) throws InterruptedException { - if (setupFlag == NO_FLAGS) { - System.out.println("Reactivesocket configured with: NO_FLAGS"); - } else if (setupFlag == HONOR_LEASE) { - System.out.println("Reactivesocket configured with: HONOR_LEASE"); - } - socketClient = DefaultReactiveSocket.fromClientConnection( - clientConnection, - ConnectionSetupPayload.create("UTF-8", "UTF-8", setupFlag), - handler, - err -> err.printStackTrace() - ); - - // start both the server and client and monitor for errors - LatchedCompletable lc = new LatchedCompletable(2); - socketServer.start(lc); - socketClient.start(lc); - if(!lc.await(3000, TimeUnit.MILLISECONDS)) { - throw new RuntimeException("Timed out waiting for startup"); - } - - awaitSocketAvailability(socketClient, 50, TimeUnit.SECONDS); - } - - private void startSockets(int setupFlag) throws InterruptedException { - startSockets(setupFlag, null); - } - - private void awaitSocketAvailability(ReactiveSocket socket, long timeout, TimeUnit unit) { - long waitTimeMs = 1L; - long startTime = System.nanoTime(); - long timeoutNanos = TimeUnit.NANOSECONDS.convert(timeout, unit); - - while (socket.availability() == 0.0) { - try { - System.out.println("... waiting " + waitTimeMs + " ..."); - Thread.sleep(waitTimeMs); - waitTimeMs = Math.min(waitTimeMs * 2, 1000L); - final long elapsedNanos = System.nanoTime() - startTime; - if (elapsedNanos > timeoutNanos) { - throw new IllegalStateException("Timeout while waiting for socket availability"); - } - } catch (InterruptedException e) { - e.printStackTrace(); - } - } - assertTrue("client socket has positive avaibility", socket.availability() > 0.0); - } - - @Test(timeout=2000) - @Theory - public void testRequestResponse(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - // perform request/response - - Publisher response = socketClient.requestResponse(TestUtil.utf8EncodedPayload("hello", null)); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertNoErrors(); - ts.assertValue(TestUtil.utf8EncodedPayload("hello world", null)); - } - - @Test(timeout=2000, expected=IllegalStateException.class) - public void testRequestResponsePremature() throws InterruptedException { - socketClient = DefaultReactiveSocket.fromClientConnection( - clientConnection, - ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), - err -> err.printStackTrace() - ); - - Publisher response = socketClient.requestResponse(TestUtil.utf8EncodedPayload("hello", null)); - } - - @Test(timeout=2000) - @Theory - public void testRequestStream(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - // perform request/stream - - Publisher response = socketClient.requestStream(TestUtil.utf8EncodedPayload("hello", null)); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertNoErrors(); - assertEquals(100, ts.values().size()); - assertEquals("hello world 99", byteToString(ts.values().get(99).getData())); - } - - @Test(timeout=4000) - @Theory - public void testRequestSubscription(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - // perform request/subscription - - Publisher response = socketClient.requestSubscription(TestUtil.utf8EncodedPayload("hello", null)); - TestSubscriber ts = new TestSubscriber<>(); - TestSubscriber ts2 = new TestSubscriber<>(); - ConnectableObservable published = fromPublisher(response).publish(); - published.take(10).subscribe(ts); - published.subscribe(ts2); - Disposable subscription = published.connect(); - - // ts completed due to take - ts.awaitTerminalEvent(); - ts.assertNoErrors(); - ts.assertComplete(); - - // ts2 should never complete - ts2.assertNoErrors(); - ts2.assertNotTerminated(); - - // assert it is running still - assertTrue(helloSubscriptionRunning.get()); - - // shut down the work - subscription.dispose(); - - // wait for up to 2 seconds for the async CANCEL to occur (it sends a message up) - for (int i = 0; i < 20; i++) { - if (!helloSubscriptionRunning.get()) { - break; - } - try { - Thread.sleep(100); - } catch (InterruptedException e) { - } - } - // and then stopped after unsubscribing - assertFalse(helloSubscriptionRunning.get()); - - assertEquals(10, ts.values().size()); - assertEquals("subscription 9", byteToString(ts.values().get(9).getData())); - } - - @Test(timeout=2000) - @Theory - public void testFireAndForgetSuccess(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - - // perform request/response - - Publisher response = socketClient.fireAndForget(TestUtil.utf8EncodedPayload("log", null)); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - // these only test client side since this is fireAndForgetOrMetadataPush - ts.awaitTerminalEvent(); - ts.assertNoErrors(); - ts.assertComplete(); - // this waits for server-side - fireAndForgetOrMetadataPush.await(); - assertEquals("log", lastFireAndForget.get()); - } - - @Test(timeout=2000) - @Theory - public void testFireAndForgetServerSideErrorNotFound(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - // perform request/response - - Publisher response = socketClient.fireAndForget(TestUtil.utf8EncodedPayload("unknown", null)); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - // these only test client side since this is fireAndForgetOrMetadataPush - ts.awaitTerminalEvent(); - ts.assertNoErrors();// client-side won't see an error - ts.assertComplete(); - // this waits for server-side - fireAndForgetOrMetadataPush.await(); - assertEquals("notFound", lastFireAndForget.get()); - } - - @Test(timeout=2000) - @Theory - public void testFireAndForgetServerSideErrorHandlerBlowup(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - // perform request/response - - Publisher response = socketClient.fireAndForget(TestUtil.utf8EncodedPayload("blowup", null)); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - // these only test client side since this is fireAndForgetOrMetadataPush - ts.awaitTerminalEvent(); - ts.assertNoErrors();// client-side won't see an error - ts.assertComplete(); - // this waits for server-side - fireAndForgetOrMetadataPush.await(); - assertEquals("blowup", lastFireAndForget.get()); - lastServerErrorCountDown.await(); - assertEquals("forced blowup to simulate handler error", lastServerError.get().getCause().getMessage()); - } - - @Test(timeout=2000) - @Theory - public void testRequestChannelEcho(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - - Publisher inputs = just( - TestUtil.utf8EncodedPayload("1", "echo"), - TestUtil.utf8EncodedPayload("2", "echo") - ); - Publisher outputs = socketClient.requestChannel(inputs); - TestSubscriber ts = new TestSubscriber<>(); - outputs.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertNoErrors(); - assertEquals(2, ts.values().size()); - assertEquals("1_echo", byteToString(ts.values().get(0).getData())); - assertEquals("2_echo", byteToString(ts.values().get(1).getData())); - } - - @Test(timeout=2000) - @Theory - public void testRequestChannelNotFound(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - - Publisher requestStream = just(TestUtil.utf8EncodedPayload(null, "someChannel")); - Publisher response = socketClient.requestChannel(requestStream); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertTerminated(); - ts.assertNotComplete(); - ts.assertNoValues(); - ts.assertErrorMessage("Not Found"); - } - - @Test(timeout=2000) - @Theory - public void testMetadataPushSuccess(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - - // perform request/response - - Publisher response = socketClient.metadataPush(TestUtil.utf8EncodedPayload(null, "log")); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertNoErrors(); - ts.assertComplete(); - // this waits for server-side - fireAndForgetOrMetadataPush.await(); - assertEquals("log", lastMetadataPush.get()); - } - - @Test(timeout=2000) - @Theory - public void testMetadataPushServerSideErrorNotFound(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - // perform request/response - - Publisher response = socketClient.metadataPush(TestUtil.utf8EncodedPayload(null, "unknown")); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertNoErrors();// client-side won't see an error - ts.assertComplete(); - // this waits for server-side - fireAndForgetOrMetadataPush.await(); - assertEquals("notFound", lastMetadataPush.get()); - } - - @Test(timeout=2000) - @Theory - public void testMetadataPushServerSideErrorHandlerBlowup(int setupFlag) throws InterruptedException { - startSockets(setupFlag); - // perform request/response - - Publisher response = socketClient.metadataPush(TestUtil.utf8EncodedPayload(null, "blowup")); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertNoErrors();// client-side won't see an error - ts.assertComplete(); - // this waits for server-side - fireAndForgetOrMetadataPush.await(); - assertEquals("blowup", lastMetadataPush.get()); - lastServerErrorCountDown.await(); - assertEquals("forced blowup to simulate handler error", lastServerError.get().getCause().getMessage()); - } - - @Test(timeout=2000) - @Theory - public void testServerRequestResponse(int setupFlag) throws InterruptedException { - startSockets(setupFlag, new RequestHandler.Builder() - .withRequestResponse(payload -> { - return just(utf8EncodedPayload("hello world from client", null)); - }).build()); - - CountDownLatch latch = new CountDownLatch(1); - socketServer.onRequestReady(err -> { - latch.countDown(); - }); - latch.await(); - - Publisher response = socketServer.requestResponse(TestUtil.utf8EncodedPayload("hello", null)); - TestSubscriber ts = new TestSubscriber<>(); - response.subscribe(ts); - ts.awaitTerminalEvent(); - ts.assertNoErrors(); - ts.assertValue(TestUtil.utf8EncodedPayload("hello world from client", null)); - } - - -} diff --git a/src/test/java/io/reactivesocket/SerializedEventBus.java b/src/test/java/io/reactivesocket/SerializedEventBus.java deleted file mode 100644 index 01018b9ee..000000000 --- a/src/test/java/io/reactivesocket/SerializedEventBus.java +++ /dev/null @@ -1,80 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import java.util.concurrent.CopyOnWriteArrayList; -import java.util.function.Consumer; - -import io.reactivesocket.rx.Observer; -import io.reactivex.subjects.PublishSubject; -import io.reactivex.subjects.Subject; - -/** - * Multicast eventbus that serializes incoming events. - */ -public class SerializedEventBus { - - private final CopyOnWriteArrayList> os = new CopyOnWriteArrayList<>(); - private Subject s; - - public SerializedEventBus() { - s = PublishSubject.create().toSerialized(); - s.subscribe(f-> { - for (Observer o : os) { - o.onNext(f); - } - }); - } - - public void send(Frame f) { - s.onNext(f); - } - - public void add(Observer o) { - os.add(o); - } - - public void add(Consumer f) { - add(new Observer() { - - @Override - public void onNext(Frame t) { - f.accept(t); - } - - @Override - public void onError(Throwable e) { - - } - - @Override - public void onComplete() { - - } - - @Override - public void onSubscribe(io.reactivesocket.rx.Disposable d) { - // TODO Auto-generated method stub - - } - - }); - } - - public void remove(Observer o) { - os.remove(o); - } -} \ No newline at end of file diff --git a/src/test/java/io/reactivesocket/TestConnection.java b/src/test/java/io/reactivesocket/TestConnection.java deleted file mode 100644 index 9f1c4e7b5..000000000 --- a/src/test/java/io/reactivesocket/TestConnection.java +++ /dev/null @@ -1,109 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import static io.reactivex.Observable.*; - -import java.io.IOException; - -import org.reactivestreams.Publisher; - -import io.reactivesocket.rx.Completable; -import io.reactivesocket.rx.Observer; -import io.reactivex.Observable; -import io.reactivex.Scheduler.Worker; -import io.reactivex.schedulers.Schedulers; - -public class TestConnection implements DuplexConnection { - - public final SerializedEventBus toInput = new SerializedEventBus(); - public final SerializedEventBus write = new SerializedEventBus(); - - @Override - public void addOutput(Publisher o, Completable callback) { - fromPublisher(o).flatMap(m -> { - // no backpressure on a Subject so just firehosing for this test - write.send(m); - return Observable. empty(); - }).subscribe(v -> { - } , callback::error, callback::success); - } - - @Override - public void addOutput(Frame f, Completable callback) { - write.send(f); - callback.success(); - } - - @Override - public io.reactivesocket.rx.Observable getInput() { - return new io.reactivesocket.rx.Observable() { - - @Override - public void subscribe(Observer o) { - toInput.add(o); - // we are okay with the race of sending data and cancelling ... since this is "hot" by definition and unsubscribing is a race. - o.onSubscribe(new io.reactivesocket.rx.Disposable() { - - @Override - public void dispose() { - toInput.remove(o); - } - - }); - } - - }; - } - - public void connectToServerConnection(TestConnection serverConnection) { - connectToServerConnection(serverConnection, true); - } - - Worker clientThread = Schedulers.newThread().createWorker(); - Worker serverThread = Schedulers.newThread().createWorker(); - - public void connectToServerConnection(TestConnection serverConnection, boolean log) { - if (log) { - serverConnection.write.add(n -> System.out.println("SERVER ==> Writes from server->client: " + n + " Written from " + Thread.currentThread())); - serverConnection.toInput.add(n -> System.out.println("SERVER <== Input from client->server: " + n + " Read on " + Thread.currentThread())); - write.add(n -> System.out.println("CLIENT ==> Writes from client->server: " + n + " Written from " + Thread.currentThread())); - toInput.add(n -> System.out.println("CLIENT <== Input from server->client: " + n + " Read on " + Thread.currentThread())); - } - - // client to server - write.add(f -> { -// serverConnection.toInput.send(f); - serverThread.schedule(() -> { - serverConnection.toInput.send(f); - }); - }); - // server to client - serverConnection.write.add(f -> { -// toInput.send(f); - clientThread.schedule(() -> { - toInput.send(f); - }); - }); - } - - @Override - public void close() throws IOException { - clientThread.dispose(); - serverThread.dispose(); - } - -} \ No newline at end of file diff --git a/src/test/java/io/reactivesocket/TestConnectionWithControlledRequestN.java b/src/test/java/io/reactivesocket/TestConnectionWithControlledRequestN.java deleted file mode 100644 index fc4f3595f..000000000 --- a/src/test/java/io/reactivesocket/TestConnectionWithControlledRequestN.java +++ /dev/null @@ -1,124 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.atomic.AtomicLong; - -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import io.reactivesocket.rx.Completable; - -/** - * Connection that by defaults only calls request(1) on a Publisher to addOutput. Any further must be done via requestMore(n) - *

- * NOTE: This should ONLY be used for 1 test at a time as it maintains state. Call close() when done. - */ -public class TestConnectionWithControlledRequestN extends TestConnection { - - public List subscriptions = Collections.synchronizedList(new ArrayList()); - public AtomicLong emitted = new AtomicLong(); - public AtomicLong requested = new AtomicLong(); - - @Override - public void addOutput(Publisher o, Completable callback) { - System.out.println("TestConnectionWithControlledRequestN => addOutput"); - o.subscribe(new Subscriber() { - - volatile Subscription _s = null; - public AtomicLong sEmitted = new AtomicLong(); - - @Override - public void onSubscribe(Subscription s) { - _s = new Subscription() { - - @Override - public void request(long n) { - requested.addAndGet(n); - s.request(n); - } - - @Override - public void cancel() { - subscriptions.remove(_s); - s.cancel(); - } - - }; - subscriptions.add(_s); - _s.request(1); - } - - @Override - public void onNext(Frame t) { - emitted.incrementAndGet(); - sEmitted.incrementAndGet(); - write.send(t); - } - - @Override - public void onError(Throwable t) { - subscriptions.remove(_s); - callback.error(t); - } - - @Override - public void onComplete() { - System.out.println("TestConnectionWithControlledRequestN => complete, emitted: " + sEmitted.get()); - subscriptions.remove(_s); - callback.success(); - } - - }); - } - - @Override - public void addOutput(Frame f, Completable callback) { - emitted.incrementAndGet(); - write.send(f); - callback.success(); - } - - public boolean awaitSubscription(int timeInMillis) { - long start = System.currentTimeMillis(); - while (subscriptions.size() == 0) { - Thread.yield(); - if(System.currentTimeMillis() - start > timeInMillis) { - return false; - } - } - return true; - } - - /** - * Request more against the first subscription. This will ONLY request against the oldest Subscription, one at a time. - *

- * When one completes, it does NOT propagate request(n) to the next. Thus, this assumes unit tests where you know what you are doing with request(n). - * - * @param n - */ - public void requestMore(int n) { - if (subscriptions.size() == 0) { - throw new IllegalStateException("no subscriptions to request from"); - } - subscriptions.get(0).request(n); - } - -} \ No newline at end of file diff --git a/src/test/java/io/reactivesocket/TestFlowControlRequestN.java b/src/test/java/io/reactivesocket/TestFlowControlRequestN.java deleted file mode 100644 index 1d877ea03..000000000 --- a/src/test/java/io/reactivesocket/TestFlowControlRequestN.java +++ /dev/null @@ -1,463 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; -import org.reactivestreams.Publisher; -import org.reactivestreams.Subscriber; -import org.reactivestreams.Subscription; - -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; - -import static io.reactivesocket.ConnectionSetupPayload.NO_FLAGS; -import static io.reactivesocket.TestUtil.byteToString; -import static io.reactivesocket.TestUtil.utf8EncodedPayload; -import static io.reactivex.Observable.error; -import static io.reactivex.Observable.fromPublisher; -import static io.reactivex.Observable.range; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -public class TestFlowControlRequestN { - - @Test(timeout=2000) - public void testRequestStream_batches() throws InterruptedException { - ControlledSubscriber s = new ControlledSubscriber(); - socketClient.requestStream(utf8EncodedPayload("100", null)).subscribe(s); - assertEquals(0, s.received.get()); - assertEquals(0, emitted.get()); - s.subscription.request(10); - waitForAsyncValue(s.received, 10); - assertEquals(10, s.received.get()); - assertEquals(10, emitted.get()); - s.subscription.request(50); - waitForAsyncValue(s.received, 60); - assertEquals(60, s.received.get()); - assertEquals(60, emitted.get()); - s.subscription.request(100); - waitForAsyncValue(s.received, 100); - assertEquals(100, s.received.get()); - s.terminated.await(); - assertEquals(100, emitted.get()); - - assertTrue(s.completed.get()); - } - - @Test(timeout=3000) - public void testRequestStream_fastProducer_slowConsumer_maxValueRequest() throws InterruptedException { - CountDownLatch latch = new CountDownLatch(1); - CountDownLatch cancelled = new CountDownLatch(1); - AtomicInteger received = new AtomicInteger(); - socketClient.requestStream(utf8EncodedPayload("10000", null)).subscribe(new Subscriber() { - - Subscription subscription; - - @Override - public void onSubscribe(Subscription s) { - subscription = s; - s.request(Long.MAX_VALUE); // act like a synchronous consumer that doesn't need backpressure - } - - @Override - public void onNext(Payload t) { - int r = received.incrementAndGet(); - System.out.println("onNext " + r); - if (r == 10) - { - // be a "slow" consumer - try { - Thread.sleep(1000); - } catch (InterruptedException e) { } - System.out.println("Emitted on server: " + emitted.get() - + " Received on client: " + received); - } - else if (r == 200) { - System.out.println("Cancel"); - // cancel - subscription.cancel(); - cancelled.countDown(); - onComplete(); - } - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - latch.countDown(); - } - - @Override - public void onComplete() { - System.out.println("complete"); - latch.countDown(); - } - - }); - - System.out.println("waiting"); - latch.await(3000, TimeUnit.MILLISECONDS); - cancelled.await(3000, TimeUnit.MILLISECONDS); - assertEquals(200, received.get()); - if(emitted.get() > 1024) { - fail("Emitted more than expected"); - } - } - - @Test(timeout=2000) - public void testRequestSubscription_batches() throws InterruptedException { - ControlledSubscriber s = new ControlledSubscriber(); - socketClient.requestSubscription(utf8EncodedPayload("", null)).subscribe(s); - assertEquals(0, s.received.get()); - assertEquals(0, emitted.get()); - s.subscription.request(10); - waitForAsyncValue(s.received, 10); - assertEquals(10, s.received.get()); - assertEquals(10, emitted.get()); - s.subscription.request(50); - waitForAsyncValue(s.received, 60); - assertEquals(60, s.received.get()); - assertEquals(60, emitted.get()); - s.subscription.request(100); - waitForAsyncValue(s.received, 160); - assertEquals(160, s.received.get()); - s.subscription.cancel(); - Thread.sleep(100); - assertEquals(160, emitted.get()); - } - - /** - * Test that downstream is governed by request(n) - * @throws InterruptedException - */ - @Test(timeout=2000) - public void testRequestChannel_batches_downstream() throws InterruptedException { - ControlledSubscriber s = new ControlledSubscriber(); - socketClient.requestChannel( - range(1, 10).map(i -> utf8EncodedPayload(String.valueOf(i), "1000")) - ).subscribe(s); - - // if flatMap is being used, then each of the 10 streams will emit at least 128 (default) - - assertEquals(0, s.received.get()); - assertEquals(0, emitted.get()); - s.subscription.request(10); - waitForAsyncValue(s.received, 10); - assertEquals(10, s.received.get()); - s.subscription.request(300); - waitForAsyncValue(s.received, 310); - assertEquals(310, s.received.get()); - s.subscription.request(2000); - waitForAsyncValue(s.received, 2310); - assertEquals(2310, s.received.get()); - s.subscription.cancel(); - Thread.sleep(100); - assertEquals(2310, s.received.get()); - // emitted with `flatMap` does internal buffering, so it won't be exactly 2310, - // but it should be far less than the potential 10,000 - if(emitted.get() > 4096) { - fail("Emitted " + emitted.get()); - } - } - - /** - * Test that the upstream is governed by request(n) - * @throws InterruptedException - */ - @Test(timeout=2000) - public void testRequestChannel_batches_upstream_echo() throws InterruptedException { - ControlledSubscriber s = new ControlledSubscriber(); - AtomicInteger emittedClient = new AtomicInteger(); - socketClient.requestChannel( - range(1, 10000) - .doOnNext(n -> emittedClient.incrementAndGet()) - .doOnRequest(r -> System.out.println("CLIENT REQUESTS requestN: " + r)) - .map(i -> { - // metadata to route us to the echo behavior (only actually need - // this in the first payload) - return utf8EncodedPayload(String.valueOf(i), "echo"); - })).subscribe(s); - - assertEquals(0, s.received.get()); - assertEquals(0, emitted.get()); - assertEquals(0, emittedClient.get()); - s.subscription.request(10); - waitForAsyncValue(s.received, 10); - assertEquals(10, emittedClient.get()); - assertEquals(10, s.received.get()); - s.subscription.request(200); - waitForAsyncValue(s.received, 210); - assertEquals(210, emittedClient.get()); - assertEquals(210, s.received.get()); - Thread.sleep(100); - assertFalse(s.error.get()); - - System.out.println(">>> Client sent " + emittedClient.get() - + " requests and received " + s.received.get() + " responses"); - } - - /** - * Test that the upstream is governed by request(n) - * @throws InterruptedException - */ - @Test(timeout=2000) - public void testRequestChannel_batches_upstream_decoupled() throws InterruptedException { - ControlledSubscriber s = new ControlledSubscriber(); - AtomicInteger emittedClient = new AtomicInteger(); - socketClient.requestChannel( - range(1, 10000) - .doOnNext(n -> emittedClient.incrementAndGet()) - .doOnRequest(r -> System.out.println("CLIENT REQUESTS requestN: " + r)) - .map(i -> { - // metadata to route us to the echo behavior (only actually need this - // in the first payload) - return utf8EncodedPayload(String.valueOf(i), "decoupled"); - })).subscribe(s); - - assertEquals(0, s.received.get()); - assertEquals(0, emitted.get()); - assertEquals(0, emittedClient.get()); - s.subscription.request(10); - waitForAsyncValue(s.received, 10); - assertEquals(10, s.received.get()); - s.subscription.request(200); - waitForAsyncValue(s.received, 210); - assertEquals(210, s.received.get()); - Thread.sleep(100); - assertFalse(s.error.get()); - // the responder side of 'decoupled' is limiting to 300 (batches of 50 and 250) - // so we should only emit 300 of the possible 10000 - assertEquals(300, emittedClient.get()); - - System.out.println(">>> Client sent " + emittedClient.get() - + " requests and received " + s.received.get() + " responses"); - } - - private void waitForAsyncValue(AtomicInteger value, int n) throws InterruptedException { - while (value.get() != n && !Thread.interrupted()) { - Thread.sleep(1); - } - } - - private static class ControlledSubscriber implements Subscriber { - - AtomicInteger received = new AtomicInteger(); - Subscription subscription; - CountDownLatch terminated = new CountDownLatch(1); - AtomicBoolean completed = new AtomicBoolean(false); - AtomicBoolean error = new AtomicBoolean(false); - - @Override - public void onSubscribe(Subscription s) { - this.subscription = s; - } - - @Override - public void onNext(Payload t) { - received.incrementAndGet(); - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - error.set(true); - terminated.countDown(); - } - - @Override - public void onComplete() { - completed.set(true); - terminated.countDown(); - } - - } - - private static TestConnection serverConnection; - private static TestConnection clientConnection; - private static ReactiveSocket socketServer; - private static ReactiveSocket socketClient; - private static AtomicInteger emitted = new AtomicInteger(); - private static AtomicInteger numRequests = new AtomicInteger(); - private static AtomicLong requested = new AtomicLong(); - - @Before - public void init() { - emitted.set(0); - requested.set(0); - numRequests.set(0); - } - - @BeforeClass - public static void setup() throws InterruptedException { - serverConnection = new TestConnection(); - clientConnection = new TestConnection(); - clientConnection.connectToServerConnection(serverConnection, false); - - - socketServer = DefaultReactiveSocket.fromServerConnection(serverConnection, (setup,rs) -> new RequestHandler() { - - @Override - public Publisher handleRequestStream(Payload payload) { - String request = byteToString(payload.getData()); - System.out.println("responder received requestStream: " + request); - return range(0, Integer.parseInt(request)) - .doOnRequest(n -> System.out.println("requested in responder: " + n)) - .doOnRequest(r -> requested.addAndGet(r)) - .doOnRequest(r -> numRequests.incrementAndGet()) - .doOnNext(i -> emitted.incrementAndGet()) - .map(i -> utf8EncodedPayload(String.valueOf(i), null)); - } - - @Override - public Publisher handleSubscription(Payload payload) { - return range(0, Integer.MAX_VALUE) - .doOnRequest(n -> System.out.println("requested in responder: " + n)) - .doOnRequest(r -> requested.addAndGet(r)) - .doOnRequest(r -> numRequests.incrementAndGet()) - .doOnNext(i -> emitted.incrementAndGet()) - .map(i -> utf8EncodedPayload(String.valueOf(i), null)); - } - - /** - * Use Payload.metadata for routing - */ - @Override - public Publisher handleChannel(Payload initialPayload, Publisher payloads) { - String requestMetadata = byteToString(initialPayload.getMetadata()); - System.out.println("responder received requestChannel: " + requestMetadata); - - if(requestMetadata.equals("echo")) { - // TODO I want this to be concatMap instead of flatMap but apparently concatMap has a bug - return fromPublisher(payloads).map(payload -> { - String payloadData = byteToString(payload.getData()); - return utf8EncodedPayload(String.valueOf(payloadData) + "_echo", null); - }).doOnRequest(n -> System.out.println(">>> requested in echo responder: " + n)) - .doOnRequest(r -> requested.addAndGet(r)) - .doOnRequest(r -> numRequests.incrementAndGet()) - .doOnError(t -> System.out.println("Error in 'echo' handler: " + t.getMessage())) - .doOnNext(i -> emitted.incrementAndGet()); - } else if (requestMetadata.equals("decoupled")) { - /* - * Consume 300 from request and then stop requesting more (but no cancel from responder side) - */ - fromPublisher(payloads).doOnNext(payload -> { - String payloadData = byteToString(payload.getData()); - System.out.println("DECOUPLED side-effect of request: " + payloadData); - }).subscribe(new Subscriber() { - - int count=0; - Subscription s; - - @Override - public void onError(Throwable e) { - - } - - @Override - public void onNext(Payload t) { - count++; - if(count == 50) { - s.request(250); - } - } - - @Override - public void onSubscribe(Subscription s) { - this.s = s; - // start with 50 - s.request(50); - } - - @Override - public void onComplete() { - // TODO Auto-generated method stub - - } - - - }); - - return range(1, 1000) - .doOnNext(n -> System.out.println("RESPONDER sending value: " + n)) - .map(i -> { - return utf8EncodedPayload(String.valueOf(i) + "_decoupled", null); - }) - .doOnRequest(n -> System.out.println(">>> requested in decoupled responder: " + n)) - .doOnRequest(r -> requested.addAndGet(r)) - .doOnRequest(r -> numRequests.incrementAndGet()) - .doOnError(t -> System.out.println("Error in 'decoupled' handler: " + t.getMessage())) - .doOnNext(i -> emitted.incrementAndGet()); - } else { - // TODO I want this to be concatMap instead of flatMap but apparently concatMap has a bug - return fromPublisher(payloads).flatMap(payload -> { - String payloadData = byteToString(payload.getData()); - System.out.println("responder handleChannel received payload: " + payloadData); - return range(0, Integer.parseInt(requestMetadata)) - .doOnRequest(n -> System.out.println("requested in responder [" + payloadData + "]: " + n)) - .doOnRequest(r -> requested.addAndGet(r)) - .doOnRequest(r -> numRequests.incrementAndGet()) - .doOnNext(i -> emitted.incrementAndGet()) - .map(i -> utf8EncodedPayload(String.valueOf(i), null)); - }).doOnRequest(n -> System.out.println(">>> response stream request(n) in responder: " + n)); - } - } - - @Override - public Publisher handleFireAndForget(Payload payload) { - return error(new RuntimeException("Not Found")); - } - - @Override - public Publisher handleRequestResponse(Payload payload) { - return error(new RuntimeException("Not Found")); - } - - @Override - public Publisher handleMetadataPush(Payload payload) - { - return error(new RuntimeException("Not Found")); - } - }, LeaseGovernor.UNLIMITED_LEASE_GOVERNOR, Throwable::printStackTrace); - - socketClient = DefaultReactiveSocket.fromClientConnection( - clientConnection, - ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), - Throwable::printStackTrace - ); - - // start both the server and client and monitor for errors - LatchedCompletable lc = new LatchedCompletable(2); - socketServer.start(lc); - socketClient.start(lc); - if(!lc.await(3000, TimeUnit.MILLISECONDS)) { - throw new RuntimeException("Timed out waiting for startup"); - } - } - - @AfterClass - public static void shutdown() { - socketServer.shutdown(); - socketClient.shutdown(); - } -} diff --git a/src/test/java/io/reactivesocket/TestTransportRequestN.java b/src/test/java/io/reactivesocket/TestTransportRequestN.java deleted file mode 100644 index eb27ebc2c..000000000 --- a/src/test/java/io/reactivesocket/TestTransportRequestN.java +++ /dev/null @@ -1,243 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import io.reactivesocket.lease.FairLeaseGovernor; -import io.reactivex.subscribers.TestSubscriber; -import org.junit.After; -import org.junit.Ignore; -import org.junit.Test; -import org.reactivestreams.Publisher; - -import java.io.IOException; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; - -import static io.reactivesocket.TestUtil.utf8EncodedPayload; -import static io.reactivex.Observable.error; -import static io.reactivex.Observable.fromPublisher; -import static io.reactivex.Observable.interval; -import static io.reactivex.Observable.just; -import static io.reactivex.Observable.range; -import static org.junit.Assert.fail; - -/** - * Ensure that request(n) from DuplexConnection "transport" layer is respected. - * - */ -public class TestTransportRequestN { - - @Test(timeout = 3000) - public void testRequestStreamWithNFromTransport() throws InterruptedException { - clientConnection = new TestConnectionWithControlledRequestN(); - serverConnection = new TestConnectionWithControlledRequestN(); - setup(clientConnection, serverConnection); - - TestSubscriber ts = new TestSubscriber<>(); - fromPublisher(socketClient.requestStream(utf8EncodedPayload("", null))) - .take(150) - .subscribe(ts); - - // wait for server to add output - if (!serverConnection.awaitSubscription(1000)) { - fail("Did not receive subscription"); - } - // now request some data, but less than it is expected to output - serverConnection.requestMore(10); - - // since we are async, give time for emission to occur - Thread.sleep(500); - - // we should not have received more than 11 (10 + default 1 that is requested) - - if (ts.valueCount() > 11) { - fail("Received more (" + ts.valueCount() + ") than transport requested (11)"); - } - - ts.cancel(); - - // since we are async, give time for emission to occur - Thread.sleep(500); - - if (serverConnection.emitted.get() > serverConnection.requested.get()) { - fail("Emitted more (" + serverConnection.emitted.get() + ") than transport requested (" + serverConnection.requested.get() + ")"); - } - } - - @Test(timeout = 3000) - public void testRequestChannelDownstreamWithNFromTransport() throws InterruptedException { - clientConnection = new TestConnectionWithControlledRequestN(); - serverConnection = new TestConnectionWithControlledRequestN(); - setup(clientConnection, serverConnection); - - TestSubscriber ts = new TestSubscriber<>(); - fromPublisher(socketClient.requestChannel(just(utf8EncodedPayload("", null)))) - .take(150) - .subscribe(ts); - - // wait for server to add output - if (!serverConnection.awaitSubscription(1000)) { - fail("Did not receive subscription"); - } - // now request some data, but less than it is expected to output - serverConnection.requestMore(10); - - // since we are async, give time for emission to occur - Thread.sleep(500); - - // we should not have received more than 11 (10 + default 1 that is requested) - - if (ts.valueCount() > 11) { - fail("Received more (" + ts.valueCount() + ") than transport requested (11)"); - } - - ts.cancel(); - - // since we are async, give time for emission to occur - Thread.sleep(500); - - if (serverConnection.emitted.get() > serverConnection.requested.get()) { - fail("Emitted more (" + serverConnection.emitted.get() + ") than transport requested (" + serverConnection.requested.get() + ")"); - } - } - - // TODO come back after some other work (Ben) - @Ignore - @Test(timeout = 3000) - public void testRequestChannelUpstreamWithNFromTransport() throws InterruptedException { - clientConnection = new TestConnectionWithControlledRequestN(); - serverConnection = new TestConnectionWithControlledRequestN(); - setup(clientConnection, serverConnection); - - TestSubscriber ts = new TestSubscriber<>(); - fromPublisher(socketClient.requestChannel(range(0, 1000).map(i -> utf8EncodedPayload("" + i, null)))) - .take(10) - .subscribe(ts); - - // wait for server to add output - if (!serverConnection.awaitSubscription(1000)) { - fail("Did not receive subscription"); - } - // now request some data, but less than it is expected to output - serverConnection.requestMore(10); -// clientConnection.requestMore(2); - - // since we are async, give time for emission to occur - Thread.sleep(500); - - // we should not have received more than 11 (10 + default 1 that is requested) - - if (ts.valueCount() > 11) { - fail("Received more (" + ts.valueCount() + ") than transport requested (11)"); - } - - ts.cancel(); - - // since we are async, give time for emission to occur - Thread.sleep(500); - - if (serverConnection.emitted.get() > serverConnection.requested.get()) { - fail("Server Emitted more (" + serverConnection.emitted.get() + ") than transport requested (" + serverConnection.requested.get() + ")"); - } - - if (clientConnection.emitted.get() > clientConnection.requested.get()) { - fail("Client Emitted more (" + clientConnection.emitted.get() + ") than transport requested (" + clientConnection.requested.get() + ")"); - } - } - - private TestConnectionWithControlledRequestN serverConnection; - private TestConnectionWithControlledRequestN clientConnection; - private ReactiveSocket socketServer; - private ReactiveSocket socketClient; - private AtomicBoolean helloSubscriptionRunning = new AtomicBoolean(false); - private AtomicReference lastServerError = new AtomicReference<>(); - private CountDownLatch lastServerErrorCountDown; - - public void setup(TestConnectionWithControlledRequestN clientConnection, TestConnectionWithControlledRequestN serverConnection) throws InterruptedException { - clientConnection.connectToServerConnection(serverConnection, false); - lastServerErrorCountDown = new CountDownLatch(1); - - socketServer = DefaultReactiveSocket.fromServerConnection(serverConnection, (setup,rs) -> new RequestHandler() { - - @Override - public Publisher handleRequestResponse(Payload payload) { - return just(utf8EncodedPayload("request_response", null)); - } - - @Override - public Publisher handleRequestStream(Payload payload) { - return range(0, 10000).map(i -> "stream_response_" + i).map(n -> utf8EncodedPayload(n, null)); - } - - @Override - public Publisher handleSubscription(Payload payload) { - return interval(1, TimeUnit.MILLISECONDS) - .onBackpressureDrop() - .doOnSubscribe(s -> helloSubscriptionRunning.set(true)) - .doOnCancel(() -> helloSubscriptionRunning.set(false)) - .map(i -> "subscription " + i) - .map(n -> utf8EncodedPayload(n, null)); - } - - @Override - public Publisher handleFireAndForget(Payload payload) { - return error(new RuntimeException("Not Found")); - } - - /** - * Use Payload.metadata for routing - */ - @Override - public Publisher handleChannel(Payload initialPayload, Publisher inputs) { - return range(0, 10000).map(i -> "channel_response_" + i).map(n -> utf8EncodedPayload(n, null)); - } - - @Override - public Publisher handleMetadataPush(Payload payload) { - return error(new RuntimeException("Not Found")); - } - - }, new FairLeaseGovernor(100, 10L, TimeUnit.SECONDS), t -> { - t.printStackTrace(); - lastServerError.set(t); - lastServerErrorCountDown.countDown(); - }); - - socketClient = DefaultReactiveSocket.fromClientConnection( - clientConnection, - ConnectionSetupPayload.create("UTF-8", "UTF-8", ConnectionSetupPayload.NO_FLAGS), - err -> err.printStackTrace()); - - // start both the server and client and monitor for errors - socketServer.startAndWait(); - socketClient.startAndWait(); - } - - @After - public void shutdown() { - socketServer.shutdown(); - socketClient.shutdown(); - try { - clientConnection.close(); - serverConnection.close(); - } catch (IOException e) { - e.printStackTrace(); - } - } - -} diff --git a/src/test/java/io/reactivesocket/TestUtil.java b/src/test/java/io/reactivesocket/TestUtil.java deleted file mode 100644 index 0ea5d7b12..000000000 --- a/src/test/java/io/reactivesocket/TestUtil.java +++ /dev/null @@ -1,122 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket; - -import org.agrona.MutableDirectBuffer; - -import java.nio.ByteBuffer; -import java.nio.charset.Charset; - -public class TestUtil -{ - private TestUtil() {} - - public static Frame utf8EncodedRequestFrame(final int streamId, final FrameType type, final String data, final int initialRequestN) - { - return Frame.Request.from(streamId, type, new Payload() - { - public ByteBuffer getData() - { - return byteBufferFromUtf8String(data); - } - - public ByteBuffer getMetadata() - { - return Frame.NULL_BYTEBUFFER; - } - }, initialRequestN); - } - - public static Frame utf8EncodedResponseFrame(final int streamId, final FrameType type, final String data) - { - return Frame.Response.from(streamId, type, utf8EncodedPayload(data, null)); - } - - public static Frame utf8EncodedErrorFrame(final int streamId, final String data) - { - return Frame.Error.from(streamId, new Exception(data)); - } - - public static Payload utf8EncodedPayload(final String data, final String metadata) - { - return new PayloadImpl(data, metadata); - } - - public static String byteToString(final ByteBuffer byteBuffer) - { - final byte[] bytes = new byte[byteBuffer.remaining()]; - byteBuffer.get(bytes); - return new String(bytes, Charset.forName("UTF-8")); - } - - public static ByteBuffer byteBufferFromUtf8String(final String data) - { - final byte[] bytes = data.getBytes(Charset.forName("UTF-8")); - return ByteBuffer.wrap(bytes); - } - - public static void copyFrame(final MutableDirectBuffer dst, final int offset, final Frame frame) - { - dst.putBytes(offset, frame.getByteBuffer(), frame.offset(), frame.length()); - } - - private static class PayloadImpl implements Payload // some JDK shoutout - { - private ByteBuffer data; - private ByteBuffer metadata; - - public PayloadImpl(final String data, final String metadata) - { - if (null == data) - { - this.data = ByteBuffer.allocate(0); - } - else - { - this.data = byteBufferFromUtf8String(data); - } - - if (null == metadata) - { - this.metadata = ByteBuffer.allocate(0); - } - else - { - this.metadata = byteBufferFromUtf8String(metadata); - } - } - - public boolean equals(Object obj) - { - System.out.println("equals: " + obj); - final Payload rhs = (Payload)obj; - - return (TestUtil.byteToString(data).equals(TestUtil.byteToString(rhs.getData()))) && - (TestUtil.byteToString(metadata).equals(TestUtil.byteToString(rhs.getMetadata()))); - } - - public ByteBuffer getData() - { - return data; - } - - public ByteBuffer getMetadata() - { - return metadata; - } - } - -} diff --git a/src/test/java/io/reactivesocket/internal/FragmenterTest.java b/src/test/java/io/reactivesocket/internal/FragmenterTest.java deleted file mode 100644 index 686394d08..000000000 --- a/src/test/java/io/reactivesocket/internal/FragmenterTest.java +++ /dev/null @@ -1,203 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import io.reactivesocket.Frame; -import io.reactivesocket.Payload; -import io.reactivesocket.TestUtil; -import io.reactivesocket.internal.frame.FrameHeaderFlyweight; -import io.reactivesocket.internal.frame.PayloadFragmenter; - -import org.junit.Test; - -import static org.junit.Assert.*; - -public class FragmenterTest -{ - private static final int METADATA_MTU = 256; - private static final int DATA_MTU = 256; - private static final int STREAM_ID = 101; - private static final int REQUEST_N = 102; - - @Test - public void shouldPassThroughUnfragmentedResponse() - { - final PayloadFragmenter fragmenter = new PayloadFragmenter(METADATA_MTU, DATA_MTU); - final Payload payload = TestUtil.utf8EncodedPayload("response data", "response metadata"); - - fragmenter.resetForResponse(STREAM_ID, payload); - - assertTrue(fragmenter.hasNext()); - final Frame frame1 = fragmenter.next(); - - assertEquals("response data", TestUtil.byteToString(frame1.getData())); - assertEquals("response metadata", TestUtil.byteToString(frame1.getMetadata())); - assertEquals(0, (frame1.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - assertFalse(fragmenter.hasNext()); - } - - @Test - public void shouldHandleFragmentedResponseData() - { - final String responseData0 = "response "; - final String responseData1 = "data"; - final String responseData = responseData0 + responseData1; - final PayloadFragmenter fragmenter = new PayloadFragmenter(METADATA_MTU, responseData0.length()); - final Payload payload = TestUtil.utf8EncodedPayload(responseData, "response metadata"); - - fragmenter.resetForResponse(STREAM_ID, payload); - - assertTrue(fragmenter.hasNext()); - final Frame frame1 = fragmenter.next(); - - assertEquals(responseData0, TestUtil.byteToString(frame1.getData())); - assertEquals("response metadata", TestUtil.byteToString(frame1.getMetadata())); - assertEquals(FrameHeaderFlyweight.FLAGS_RESPONSE_F, (frame1.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - - assertTrue(fragmenter.hasNext()); - final Frame frame2 = fragmenter.next(); - - assertEquals(responseData1, TestUtil.byteToString(frame2.getData())); - assertEquals("", TestUtil.byteToString(frame2.getMetadata())); - assertEquals(0, (frame2.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - assertFalse(fragmenter.hasNext()); - } - - @Test - public void shouldHandleFragmentedResponseMetadata() - { - final String responseMetadata0 = "response "; - final String responseMetadata1 = "metadata"; - final String responseMetadata = responseMetadata0 + responseMetadata1; - final PayloadFragmenter fragmenter = new PayloadFragmenter(responseMetadata0.length(), DATA_MTU); - final Payload payload = TestUtil.utf8EncodedPayload("response data", responseMetadata); - - fragmenter.resetForResponse(STREAM_ID, payload); - - assertTrue(fragmenter.hasNext()); - final Frame frame1 = fragmenter.next(); - - assertEquals("response data", TestUtil.byteToString(frame1.getData())); - assertEquals(responseMetadata0, TestUtil.byteToString(frame1.getMetadata())); - assertEquals(FrameHeaderFlyweight.FLAGS_RESPONSE_F, (frame1.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - - assertTrue(fragmenter.hasNext()); - final Frame frame2 = fragmenter.next(); - - assertEquals("", TestUtil.byteToString(frame2.getData())); - assertEquals(responseMetadata1, TestUtil.byteToString(frame2.getMetadata())); - assertEquals(0, (frame2.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - assertFalse(fragmenter.hasNext()); - } - - @Test - public void shouldHandleFragmentedResponseMetadataAndData() - { - final String responseMetadata0 = "response "; - final String responseMetadata1 = "metadata"; - final String responseMetadata = responseMetadata0 + responseMetadata1; - final String responseData0 = "response "; - final String responseData1 = "data"; - final String responseData = responseData0 + responseData1; - final PayloadFragmenter fragmenter = new PayloadFragmenter(responseMetadata0.length(), responseData0.length()); - final Payload payload = TestUtil.utf8EncodedPayload(responseData, responseMetadata); - - fragmenter.resetForResponse(STREAM_ID, payload); - - assertTrue(fragmenter.hasNext()); - final Frame frame1 = fragmenter.next(); - - assertEquals(responseData0, TestUtil.byteToString(frame1.getData())); - assertEquals(responseMetadata0, TestUtil.byteToString(frame1.getMetadata())); - assertEquals(FrameHeaderFlyweight.FLAGS_RESPONSE_F, (frame1.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - - assertTrue(fragmenter.hasNext()); - final Frame frame2 = fragmenter.next(); - - assertEquals(responseData1, TestUtil.byteToString(frame2.getData())); - assertEquals(responseMetadata1, TestUtil.byteToString(frame2.getMetadata())); - assertEquals(0, (frame2.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - assertFalse(fragmenter.hasNext()); - } - - @Test - public void shouldHandleFragmentedResponseMetadataAndDataWithMoreThanTwoFragments() - { - final String responseMetadata0 = "response "; - final String responseMetadata1 = "metadata"; - final String responseMetadata = responseMetadata0 + responseMetadata1; - final String responseData0 = "response "; - final String responseData1 = "data "; - final String responseData2 = "and more"; - final String responseData = responseData0 + responseData1 + responseData2; - final PayloadFragmenter fragmenter = new PayloadFragmenter(responseMetadata0.length(), responseData0.length()); - final Payload payload = TestUtil.utf8EncodedPayload(responseData, responseMetadata); - - fragmenter.resetForResponse(STREAM_ID, payload); - - assertTrue(fragmenter.hasNext()); - final Frame frame1 = fragmenter.next(); - - assertEquals(responseData0, TestUtil.byteToString(frame1.getData())); - assertEquals(responseMetadata0, TestUtil.byteToString(frame1.getMetadata())); - assertEquals(FrameHeaderFlyweight.FLAGS_RESPONSE_F, (frame1.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - - assertTrue(fragmenter.hasNext()); - final Frame frame2 = fragmenter.next(); - - assertEquals(responseData1, TestUtil.byteToString(frame2.getData())); - assertEquals(responseMetadata1, TestUtil.byteToString(frame2.getMetadata())); - assertEquals(FrameHeaderFlyweight.FLAGS_RESPONSE_F, (frame2.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - - assertTrue(fragmenter.hasNext()); - final Frame frame3 = fragmenter.next(); - - assertEquals(responseData2, TestUtil.byteToString(frame3.getData())); - assertEquals("", TestUtil.byteToString(frame3.getMetadata())); - assertEquals(0, (frame3.flags() & FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - assertFalse(fragmenter.hasNext()); - } - - @Test - public void shouldHandleFragmentedRequestChannelMetadataAndData() - { - final String requestMetadata0 = "request "; - final String requestMetadata1 = "metadata"; - final String requestMetadata = requestMetadata0 + requestMetadata1; - final String requestData0 = "request "; - final String requestData1 = "data"; - final String requestData = requestData0 + requestData1; - final PayloadFragmenter fragmenter = new PayloadFragmenter(requestMetadata0.length(), requestData0.length()); - final Payload payload = TestUtil.utf8EncodedPayload(requestData, requestMetadata); - - fragmenter.resetForRequestChannel(STREAM_ID, payload, REQUEST_N); - - assertTrue(fragmenter.hasNext()); - final Frame frame1 = fragmenter.next(); - - assertEquals(requestData0, TestUtil.byteToString(frame1.getData())); - assertEquals(requestMetadata0, TestUtil.byteToString(frame1.getMetadata())); - assertEquals(FrameHeaderFlyweight.FLAGS_REQUEST_CHANNEL_F, (frame1.flags() & FrameHeaderFlyweight.FLAGS_REQUEST_CHANNEL_F)); - - assertTrue(fragmenter.hasNext()); - final Frame frame2 = fragmenter.next(); - - assertEquals(requestData1, TestUtil.byteToString(frame2.getData())); - assertEquals(requestMetadata1, TestUtil.byteToString(frame2.getMetadata())); - assertEquals(0, (frame2.flags() & FrameHeaderFlyweight.FLAGS_REQUEST_CHANNEL_F)); - assertFalse(fragmenter.hasNext()); - } -} diff --git a/src/test/java/io/reactivesocket/internal/ReassemblerTest.java b/src/test/java/io/reactivesocket/internal/ReassemblerTest.java deleted file mode 100644 index 0b134d697..000000000 --- a/src/test/java/io/reactivesocket/internal/ReassemblerTest.java +++ /dev/null @@ -1,160 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.Payload; -import io.reactivesocket.TestUtil; -import io.reactivesocket.internal.frame.FrameHeaderFlyweight; -import io.reactivesocket.internal.frame.PayloadReassembler; -import io.reactivex.subjects.ReplaySubject; -import org.junit.Test; - -import java.nio.ByteBuffer; - -import static org.junit.Assert.assertEquals; - -public class ReassemblerTest -{ - private static final int STREAM_ID = 101; - - @Test - public void shouldPassThroughUnfragmentedFrame() - { - final ReplaySubject replaySubject = ReplaySubject.create(); - final PayloadReassembler reassembler = PayloadReassembler.with(replaySubject); - final String metadata = "metadata"; - final String data = "data"; - final ByteBuffer metadataBuffer = TestUtil.byteBufferFromUtf8String(metadata); - final ByteBuffer dataBuffer = TestUtil.byteBufferFromUtf8String(data); - - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadataBuffer, dataBuffer, 0)); - - assertEquals(1, replaySubject.getValues().length); - assertEquals(data, TestUtil.byteToString(replaySubject.getValue().getData())); - assertEquals(metadata, TestUtil.byteToString(replaySubject.getValue().getMetadata())); - } - - @Test - public void shouldNotPassThroughFragmentedFrameIfStillMoreFollowing() - { - final ReplaySubject replaySubject = ReplaySubject.create(); - final PayloadReassembler reassembler = PayloadReassembler.with(replaySubject); - final String metadata = "metadata"; - final String data = "data"; - final ByteBuffer metadataBuffer = TestUtil.byteBufferFromUtf8String(metadata); - final ByteBuffer dataBuffer = TestUtil.byteBufferFromUtf8String(data); - - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadataBuffer, dataBuffer, FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - - assertEquals(0, replaySubject.getValues().length); - } - - @Test - public void shouldReassembleTwoFramesWithFragmentedDataAndMetadata() - { - final ReplaySubject replaySubject = ReplaySubject.create(); - final PayloadReassembler reassembler = PayloadReassembler.with(replaySubject); - final String metadata0 = "metadata0"; - final String metadata1 = "md1"; - final String metadata = metadata0 + metadata1; - final String data0 = "data0"; - final String data1 = "d1"; - final String data = data0 + data1; - final ByteBuffer metadata0Buffer = TestUtil.byteBufferFromUtf8String(metadata0); - final ByteBuffer data0Buffer = TestUtil.byteBufferFromUtf8String(data0); - final ByteBuffer metadata1Buffer = TestUtil.byteBufferFromUtf8String(metadata1); - final ByteBuffer data1Buffer = TestUtil.byteBufferFromUtf8String(data1); - - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadata0Buffer, data0Buffer, FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadata1Buffer, data1Buffer, 0)); - - assertEquals(1, replaySubject.getValues().length); - assertEquals(data, TestUtil.byteToString(replaySubject.getValue().getData())); - assertEquals(metadata, TestUtil.byteToString(replaySubject.getValue().getMetadata())); - } - - @Test - public void shouldReassembleTwoFramesWithFragmentedData() - { - final ReplaySubject replaySubject = ReplaySubject.create(); - final PayloadReassembler reassembler = PayloadReassembler.with(replaySubject); - final String metadata = "metadata"; - final String data0 = "data0"; - final String data1 = "d1"; - final String data = data0 + data1; - final ByteBuffer metadataBuffer = TestUtil.byteBufferFromUtf8String(metadata); - final ByteBuffer data0Buffer = TestUtil.byteBufferFromUtf8String(data0); - final ByteBuffer data1Buffer = TestUtil.byteBufferFromUtf8String(data1); - - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadataBuffer, data0Buffer, FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, Frame.NULL_BYTEBUFFER, data1Buffer, 0)); - - assertEquals(1, replaySubject.getValues().length); - assertEquals(data, TestUtil.byteToString(replaySubject.getValue().getData())); - assertEquals(metadata, TestUtil.byteToString(replaySubject.getValue().getMetadata())); - } - - @Test - public void shouldReassembleTwoFramesWithFragmentedMetadata() - { - final ReplaySubject replaySubject = ReplaySubject.create(); - final PayloadReassembler reassembler = PayloadReassembler.with(replaySubject); - final String metadata0 = "metadata0"; - final String metadata1 = "md1"; - final String metadata = metadata0 + metadata1; - final String data = "data"; - final ByteBuffer metadata0Buffer = TestUtil.byteBufferFromUtf8String(metadata0); - final ByteBuffer dataBuffer = TestUtil.byteBufferFromUtf8String(data); - final ByteBuffer metadata1Buffer = TestUtil.byteBufferFromUtf8String(metadata1); - - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadata0Buffer, dataBuffer, FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadata1Buffer, Frame.NULL_BYTEBUFFER, 0)); - - assertEquals(1, replaySubject.getValues().length); - assertEquals(data, TestUtil.byteToString(replaySubject.getValue().getData())); - assertEquals(metadata, TestUtil.byteToString(replaySubject.getValue().getMetadata())); - } - - @Test - public void shouldReassembleTwoFramesWithFragmentedDataAndMetadataWithMoreThanTwoFragments() - { - final ReplaySubject replaySubject = ReplaySubject.create(); - final PayloadReassembler reassembler = PayloadReassembler.with(replaySubject); - final String metadata0 = "metadata0"; - final String metadata1 = "md1"; - final String metadata = metadata0 + metadata1; - final String data0 = "data0"; - final String data1 = "d1"; - final String data2 = "d2"; - final String data = data0 + data1 + data2; - final ByteBuffer metadata0Buffer = TestUtil.byteBufferFromUtf8String(metadata0); - final ByteBuffer data0Buffer = TestUtil.byteBufferFromUtf8String(data0); - final ByteBuffer metadata1Buffer = TestUtil.byteBufferFromUtf8String(metadata1); - final ByteBuffer data1Buffer = TestUtil.byteBufferFromUtf8String(data1); - final ByteBuffer data2Buffer = TestUtil.byteBufferFromUtf8String(data2); - - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadata0Buffer, data0Buffer, FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, metadata1Buffer, data1Buffer, FrameHeaderFlyweight.FLAGS_RESPONSE_F)); - reassembler.onNext(Frame.Response.from(STREAM_ID, FrameType.NEXT, Frame.NULL_BYTEBUFFER, data2Buffer, 0)); - - assertEquals(1, replaySubject.getValues().length); - assertEquals(data, TestUtil.byteToString(replaySubject.getValue().getData())); - assertEquals(metadata, TestUtil.byteToString(replaySubject.getValue().getMetadata())); - } - -} diff --git a/src/test/java/io/reactivesocket/internal/RequesterTest.java b/src/test/java/io/reactivesocket/internal/RequesterTest.java deleted file mode 100644 index 56a680218..000000000 --- a/src/test/java/io/reactivesocket/internal/RequesterTest.java +++ /dev/null @@ -1,275 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import static io.reactivesocket.TestUtil.*; -import static org.junit.Assert.*; -import static io.reactivesocket.ConnectionSetupPayload.NO_FLAGS; -import static io.reactivex.Observable.*; - -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; - -import org.junit.Test; - -import io.reactivesocket.ConnectionSetupPayload; -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.LatchedCompletable; -import io.reactivesocket.Payload; -import io.reactivesocket.TestConnection; -import io.reactivesocket.rx.Completable; -import io.reactivex.subscribers.TestSubscriber; -import io.reactivex.Observable; -import io.reactivex.subjects.ReplaySubject; - -public class RequesterTest -{ - final static Consumer ERROR_HANDLER = Throwable::printStackTrace; - - @Test(timeout=2000) - public void testRequestResponseSuccess() throws InterruptedException { - TestConnection conn = establishConnection(); - ReplaySubject requests = captureRequests(conn); - LatchedCompletable rc = new LatchedCompletable(1); - Requester p = Requester.createClientRequester(conn, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); - rc.await(); - - TestSubscriber ts = new TestSubscriber<>(); - p.requestResponse(utf8EncodedPayload("hello", null)).subscribe(ts); - - ts.assertNoErrors(); - assertEquals(2, requests.getValues().length); - List requested = requests.take(2).toList().toBlocking().single(); - - Frame one = requested.get(0); - assertEquals(0, one.getStreamId());// SETUP always happens on 0 - assertEquals("", byteToString(one.getData())); - assertEquals(FrameType.SETUP, one.getType()); - - Frame two = requested.get(1); - assertEquals(2, two.getStreamId());// need to start at 2, not 0 - assertEquals("hello", byteToString(two.getData())); - assertEquals(FrameType.REQUEST_RESPONSE, two.getType()); - - // now emit a response to ensure the Publisher receives and completes - conn.toInput.send(utf8EncodedResponseFrame(2, FrameType.NEXT_COMPLETE, "world")); - - ts.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - ts.assertValue(utf8EncodedPayload("world", null)); - ts.assertComplete(); - } - - @Test(timeout=2000) - public void testRequestResponseError() throws InterruptedException { - TestConnection conn = establishConnection(); - ReplaySubject requests = captureRequests(conn); - LatchedCompletable rc = new LatchedCompletable(1); - Requester p = Requester.createClientRequester(conn, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); - rc.await(); - - TestSubscriber ts = new TestSubscriber<>(); - p.requestResponse(utf8EncodedPayload("hello", null)).subscribe(ts); - - assertEquals(2, requests.getValues().length); - List requested = requests.take(2).toList().toBlocking().single(); - - Frame one = requested.get(0); - assertEquals(0, one.getStreamId());// SETUP always happens on 0 - assertEquals("", byteToString(one.getData())); - assertEquals(FrameType.SETUP, one.getType()); - - Frame two = requested.get(1); - assertEquals(2, two.getStreamId());// need to start at 2, not 0 - assertEquals("hello", byteToString(two.getData())); - assertEquals(FrameType.REQUEST_RESPONSE, two.getType()); - - conn.toInput.send(Frame.Error.from(2, new RuntimeException("Failed"))); - ts.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - ts.assertError(Exception.class); - assertEquals("Failed", ts.errors().get(0).getMessage()); - } - - @Test(timeout=2000) - public void testRequestResponseCancel() throws InterruptedException { - TestConnection conn = establishConnection(); - ReplaySubject requests = captureRequests(conn); - LatchedCompletable rc = new LatchedCompletable(1); - Requester p = Requester.createClientRequester(conn, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); - rc.await(); - - TestSubscriber ts = new TestSubscriber<>(); - p.requestResponse(utf8EncodedPayload("hello", null)).subscribe(ts); - ts.cancel(); - - assertEquals(3, requests.getValues().length); - List requested = requests.take(3).toList().toBlocking().single(); - - Frame one = requested.get(0); - assertEquals(0, one.getStreamId());// SETUP always happens on 0 - assertEquals("", byteToString(one.getData())); - assertEquals(FrameType.SETUP, one.getType()); - - Frame two = requested.get(1); - assertEquals(2, two.getStreamId());// need to start at 2, not 0 - assertEquals("hello", byteToString(two.getData())); - assertEquals(FrameType.REQUEST_RESPONSE, two.getType()); - - Frame three = requested.get(2); - assertEquals(2, three.getStreamId());// still the same stream - assertEquals("", byteToString(three.getData())); - assertEquals(FrameType.CANCEL, three.getType()); - - ts.assertNotTerminated(); - ts.assertNoValues(); - } - - // TODO REQUEST_N on initial frame not implemented yet - @Test(timeout=2000) - public void testRequestStreamSuccess() throws InterruptedException { - TestConnection conn = establishConnection(); - ReplaySubject requests = captureRequests(conn); - LatchedCompletable rc = new LatchedCompletable(1); - Requester p = Requester.createClientRequester(conn, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); - rc.await(); - - TestSubscriber ts = new TestSubscriber<>(); - fromPublisher(p.requestStream(utf8EncodedPayload("hello", null))).map(pl -> byteToString(pl.getData())).subscribe(ts); - - assertEquals(2, requests.getValues().length); - List requested = requests.take(2).toList().toBlocking().single(); - - Frame one = requested.get(0); - assertEquals(0, one.getStreamId());// SETUP always happens on 0 - assertEquals("", byteToString(one.getData())); - assertEquals(FrameType.SETUP, one.getType()); - - Frame two = requested.get(1); - assertEquals(2, two.getStreamId());// need to start at 2, not 0 - assertEquals("hello", byteToString(two.getData())); - assertEquals(FrameType.REQUEST_STREAM, two.getType()); - // TODO assert initial requestN - - // emit data - conn.toInput.send(utf8EncodedResponseFrame(2, FrameType.NEXT, "hello")); - conn.toInput.send(utf8EncodedResponseFrame(2, FrameType.NEXT, "world")); - conn.toInput.send(utf8EncodedResponseFrame(2, FrameType.COMPLETE, "")); - - ts.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - ts.assertComplete(); - ts.assertValueSequence(Arrays.asList("hello", "world")); - } - - // TODO REQUEST_N on initial frame not implemented yet - @Test(timeout=2000) - public void testRequestStreamSuccessTake2AndCancel() throws InterruptedException { - TestConnection conn = establishConnection(); - ReplaySubject requests = captureRequests(conn); - LatchedCompletable rc = new LatchedCompletable(1); - Requester p = Requester.createClientRequester(conn, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); - rc.await(); - - TestSubscriber ts = new TestSubscriber<>(); - Observable.fromPublisher(p.requestStream(utf8EncodedPayload("hello", null))).take(2).map(pl -> byteToString(pl.getData())).subscribe(ts); - - assertEquals(2, requests.getValues().length); - List requested = requests.take(2).toList().toBlocking().single(); - - Frame one = requested.get(0); - assertEquals(0, one.getStreamId());// SETUP always happens on 0 - assertEquals("", byteToString(one.getData())); - assertEquals(FrameType.SETUP, one.getType()); - - Frame two = requested.get(1); - assertEquals(2, two.getStreamId());// need to start at 2, not 0 - assertEquals("hello", byteToString(two.getData())); - assertEquals(FrameType.REQUEST_STREAM, two.getType()); - // TODO assert initial requestN - - // emit data - conn.toInput.send(utf8EncodedResponseFrame(2, FrameType.NEXT, "hello")); - conn.toInput.send(utf8EncodedResponseFrame(2, FrameType.NEXT, "world")); - - ts.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - ts.assertComplete(); - ts.assertValueSequence(Arrays.asList("hello", "world")); - - assertEquals(3, requests.getValues().length); - List requested2 = requests.take(3).toList().toBlocking().single(); - - // we should have sent a CANCEL - Frame three = requested2.get(2); - assertEquals(2, three.getStreamId());// still the same stream - assertEquals("", byteToString(three.getData())); - assertEquals(FrameType.CANCEL, three.getType()); - } - - @Test(timeout=2000) - public void testRequestStreamError() throws InterruptedException { - TestConnection conn = establishConnection(); - ReplaySubject requests = captureRequests(conn); - LatchedCompletable rc = new LatchedCompletable(1); - Requester p = Requester.createClientRequester(conn, ConnectionSetupPayload.create("UTF-8", "UTF-8", NO_FLAGS), ERROR_HANDLER, rc); - rc.await(); - - TestSubscriber ts = new TestSubscriber<>(); - p.requestStream(utf8EncodedPayload("hello", null)).subscribe(ts); - - assertEquals(2, requests.getValues().length); - List requested = requests.take(2).toList().toBlocking().single(); - - Frame one = requested.get(0); - assertEquals(0, one.getStreamId());// SETUP always happens on 0 - assertEquals("", byteToString(one.getData())); - assertEquals(FrameType.SETUP, one.getType()); - - Frame two = requested.get(1); - assertEquals(2, two.getStreamId());// need to start at 2, not 0 - assertEquals("hello", byteToString(two.getData())); - assertEquals(FrameType.REQUEST_STREAM, two.getType()); - // TODO assert initial requestN - - // emit data - conn.toInput.send(utf8EncodedResponseFrame(2, FrameType.NEXT, "hello")); - conn.toInput.send(utf8EncodedErrorFrame(2, "Failure")); - - ts.awaitTerminalEvent(500, TimeUnit.MILLISECONDS); - ts.assertError(Exception.class); - ts.assertValue(utf8EncodedPayload("hello", null)); - assertEquals("Failure", ts.errors().get(0).getMessage()); - } - - // @Test // TODO need to implement test for REQUEST_N behavior as a long stream is consumed - public void testRequestStreamRequestNReplenishing() { - // this should REQUEST(1024), receive 768, REQUEST(768), receive ... etc in a back-and-forth - } - - /* **********************************************************************************************/ - - private TestConnection establishConnection() { - return new TestConnection(); - } - - private ReplaySubject captureRequests(TestConnection conn) { - ReplaySubject rs = ReplaySubject.create(); - rs.forEach(i -> System.out.println("capturedRequest => " + i)); - conn.write.add(rs::onNext); - return rs; - } -} diff --git a/src/test/java/io/reactivesocket/internal/ResponderTest.java b/src/test/java/io/reactivesocket/internal/ResponderTest.java deleted file mode 100644 index 283ba7459..000000000 --- a/src/test/java/io/reactivesocket/internal/ResponderTest.java +++ /dev/null @@ -1,348 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.LatchedCompletable; -import io.reactivesocket.Payload; -import io.reactivesocket.ReactiveSocket; -import io.reactivesocket.RequestHandler; -import io.reactivesocket.TestConnection; -import io.reactivex.Observable; -import io.reactivex.schedulers.Schedulers; -import io.reactivex.schedulers.TestScheduler; -import io.reactivex.subjects.ReplaySubject; -import org.junit.Test; -import org.mockito.Mockito; -import org.reactivestreams.Subscription; - -import java.util.List; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; - -import static io.reactivesocket.LeaseGovernor.NULL_LEASE_GOVERNOR; -import static io.reactivesocket.TestUtil.byteToString; -import static io.reactivesocket.TestUtil.utf8EncodedPayload; -import static io.reactivesocket.TestUtil.utf8EncodedRequestFrame; -import static io.reactivex.Observable.error; -import static io.reactivex.Observable.interval; -import static io.reactivex.Observable.just; -import static io.reactivex.Observable.never; -import static io.reactivex.Observable.range; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; - -public class ResponderTest -{ - final static Consumer ERROR_HANDLER = Throwable::printStackTrace; - - @Test(timeout=2000) - public void testRequestResponseSuccess() throws InterruptedException { - ReactiveSocket reactiveSocket = Mockito.mock(ReactiveSocket.class); - TestConnection conn = establishConnection(); - LatchedCompletable lc = new LatchedCompletable(1); - Responder.createServerResponder(conn, - (setup, rs) -> - new RequestHandler.Builder().withRequestResponse( - request -> - just(utf8EncodedPayload(byteToString(request.getData()) + " world", null))).build(), - NULL_LEASE_GOVERNOR, - ERROR_HANDLER, - lc, - reactiveSocket); - lc.await(); - - ReplaySubject cachedResponses = captureResponses(conn); - sendSetupFrame(conn); - - // perform a request/response - conn.toInput.send(utf8EncodedRequestFrame(1, FrameType.REQUEST_RESPONSE, "hello", 128)); - - assertEquals(1, cachedResponses.getValues().length);// 1 onNext + 1 onCompleted - List frames = cachedResponses.take(1).toList().toBlocking().first(); - - // assert - Frame first = frames.get(0); - assertEquals(1, first.getStreamId()); - assertEquals(FrameType.NEXT_COMPLETE, first.getType()); - assertEquals("hello world", byteToString(first.getData())); - } - - @Test(timeout=2000) - public void testRequestResponseError() throws InterruptedException { - ReactiveSocket reactiveSocket = Mockito.mock(ReactiveSocket.class); - TestConnection conn = establishConnection(); - LatchedCompletable lc = new LatchedCompletable(1); - Responder.createServerResponder(conn, (setup, rs) -> new RequestHandler.Builder() - .withRequestResponse(request -> Observable.error(new Exception("Request Not Found"))).build(), - NULL_LEASE_GOVERNOR, ERROR_HANDLER, lc, reactiveSocket); - lc.await(); - - Observable cachedResponses = captureResponses(conn); - sendSetupFrame(conn); - - // perform a request/response - conn.toInput.send(utf8EncodedRequestFrame(1, FrameType.REQUEST_RESPONSE, "hello", 128)); - - // assert - Frame first = cachedResponses.toBlocking().first(); - assertEquals(1, first.getStreamId()); - assertEquals(FrameType.ERROR, first.getType()); - assertEquals("Request Not Found", byteToString(first.getData())); - } - - @Test(timeout=2000) - public void testRequestResponseCancel() throws InterruptedException { - ReactiveSocket reactiveSocket = Mockito.mock(ReactiveSocket.class); - AtomicBoolean unsubscribed = new AtomicBoolean(); - Observable delayed = never() - .cast(Payload.class) - .doOnCancel(() -> unsubscribed.set(true)); - - TestConnection conn = establishConnection(); - LatchedCompletable lc = new LatchedCompletable(1); - Responder.createServerResponder(conn, (setup, rs) -> new RequestHandler.Builder() - .withRequestResponse(request -> delayed).build(), - NULL_LEASE_GOVERNOR, ERROR_HANDLER, lc, reactiveSocket); - lc.await(); - - ReplaySubject cachedResponses = captureResponses(conn); - sendSetupFrame(conn); - - // perform a request/response - conn.toInput.send(utf8EncodedRequestFrame(1, FrameType.REQUEST_RESPONSE, "hello", 128)); - // assert no response - assertFalse(cachedResponses.hasValue()); - // unsubscribe - assertFalse(unsubscribed.get()); - conn.toInput.send(Frame.Cancel.from(1)); - assertTrue(unsubscribed.get()); - } - - @Test(timeout=2000) - public void testRequestStreamSuccess() throws InterruptedException { - ReactiveSocket reactiveSocket = Mockito.mock(ReactiveSocket.class); - TestConnection conn = establishConnection(); - LatchedCompletable lc = new LatchedCompletable(1); - Responder.createServerResponder(conn, (setup, rs) -> new RequestHandler.Builder() - .withRequestStream( - request -> range(Integer.parseInt(byteToString(request.getData())), 10).map(i -> utf8EncodedPayload(i + "!", null))).build(), - NULL_LEASE_GOVERNOR, ERROR_HANDLER, lc, reactiveSocket); - lc.await(); - - ReplaySubject cachedResponses = captureResponses(conn); - sendSetupFrame(conn); - - // perform a request/response - conn.toInput.send(utf8EncodedRequestFrame(1, FrameType.REQUEST_STREAM, "10", 128)); - - // assert - assertEquals(11, cachedResponses.getValues().length);// 10 onNext + 1 onCompleted - List frames = cachedResponses.take(11).toList().toBlocking().first(); - - // 10 onNext frames - for (int i = 0; i < 10; i++) { - assertEquals(1, frames.get(i).getStreamId()); - assertEquals(FrameType.NEXT, frames.get(i).getType()); - assertEquals((i + 10) + "!", byteToString(frames.get(i).getData())); - } - - // last message is a COMPLETE - assertEquals(1, frames.get(10).getStreamId()); - assertEquals(FrameType.COMPLETE, frames.get(10).getType()); - assertEquals("", byteToString(frames.get(10).getData())); - } - - @Test(timeout=2000) - public void testRequestStreamError() throws InterruptedException { - ReactiveSocket reactiveSocket = Mockito.mock(ReactiveSocket.class); - TestConnection conn = establishConnection(); - LatchedCompletable lc = new LatchedCompletable(1); - Responder.createServerResponder(conn, (setup,rs) -> new RequestHandler.Builder() - .withRequestStream(request -> range(Integer.parseInt(byteToString(request.getData())), 3) - .map(i -> utf8EncodedPayload(i + "!", null)) - .concatWith(error(new Exception("Error Occurred!")))).build(), - NULL_LEASE_GOVERNOR, ERROR_HANDLER, lc, reactiveSocket); - lc.await(); - - ReplaySubject cachedResponses = captureResponses(conn); - sendSetupFrame(conn); - - // perform a request/response - conn.toInput.send(utf8EncodedRequestFrame(1, FrameType.REQUEST_STREAM, "0", 128)); - - // assert - assertEquals(4, cachedResponses.getValues().length);// 3 onNext + 1 onError - List frames = cachedResponses.take(4).toList().toBlocking().first(); - - // 3 onNext frames - for (int i = 0; i < 3; i++) { - assertEquals(1, frames.get(i).getStreamId()); - assertEquals(FrameType.NEXT, frames.get(i).getType()); - assertEquals(i + "!", byteToString(frames.get(i).getData())); - } - - // last message is an ERROR - assertEquals(1, frames.get(3).getStreamId()); - assertEquals(FrameType.ERROR, frames.get(3).getType()); - assertEquals("Error Occurred!", byteToString(frames.get(3).getData())); - } - - @Test(timeout=2000) - public void testRequestStreamCancel() throws InterruptedException { - ReactiveSocket reactiveSocket = Mockito.mock(ReactiveSocket.class); - TestConnection conn = establishConnection(); - TestScheduler ts = Schedulers.test(); - LatchedCompletable lc = new LatchedCompletable(1); - Responder.createServerResponder(conn, (setup,rs) -> new RequestHandler.Builder() - .withRequestStream(request -> interval(1000, TimeUnit.MILLISECONDS, ts).map(i -> utf8EncodedPayload(i + "!", null))).build(), - NULL_LEASE_GOVERNOR, ERROR_HANDLER, lc, reactiveSocket); - lc.await(); - - ReplaySubject cachedResponses = captureResponses(conn); - sendSetupFrame(conn); - - // perform a request/response - conn.toInput.send(utf8EncodedRequestFrame(1, FrameType.REQUEST_STREAM, "/aRequest", 128)); - - // no time has passed, so no values - assertEquals(0, cachedResponses.getValues().length); - ts.advanceTimeBy(1000, TimeUnit.MILLISECONDS); - assertEquals(1, cachedResponses.getValues().length); - ts.advanceTimeBy(2000, TimeUnit.MILLISECONDS); - assertEquals(3, cachedResponses.getValues().length); - // dispose - conn.toInput.send(Frame.Cancel.from(1)); - // still only 1 message - assertEquals(3, cachedResponses.getValues().length); - // advance again, nothing should happen - ts.advanceTimeBy(1000, TimeUnit.MILLISECONDS); - // should still only have 3 message, no ERROR or COMPLETED - assertEquals(3, cachedResponses.getValues().length); - - List frames = cachedResponses.take(3).toList().toBlocking().first(); - - // 3 onNext frames - for (int i = 0; i < 3; i++) { - assertEquals(1, frames.get(i).getStreamId()); - assertEquals(FrameType.NEXT, frames.get(i).getType()); - assertEquals(i + "!", byteToString(frames.get(i).getData())); - } - } - - @Test(timeout=2000) - public void testMultiplexedStreams() throws InterruptedException { - ReactiveSocket reactiveSocket = Mockito.mock(ReactiveSocket.class); - TestScheduler ts = Schedulers.test(); - TestConnection conn = establishConnection(); - LatchedCompletable lc = new LatchedCompletable(1); - Responder.createServerResponder(conn, (setup,rs) -> new RequestHandler.Builder() - .withRequestStream(request -> interval(1000, TimeUnit.MILLISECONDS, ts).map(i -> utf8EncodedPayload(i + "_" + byteToString(request.getData()), null))).build(), - NULL_LEASE_GOVERNOR, ERROR_HANDLER, lc, reactiveSocket); - lc.await(); - - ReplaySubject cachedResponses = captureResponses(conn); - sendSetupFrame(conn); - - // perform a request/response - conn.toInput.send(utf8EncodedRequestFrame(1, FrameType.REQUEST_STREAM, "requestA", 128)); - - // no time has passed, so no values - assertEquals(0, cachedResponses.getValues().length); - ts.advanceTimeBy(1000, TimeUnit.MILLISECONDS); - // we should have 1 message from A - assertEquals(1, cachedResponses.getValues().length); - // now request another stream - conn.toInput.send(utf8EncodedRequestFrame(2, FrameType.REQUEST_STREAM, "requestB", 128)); - // advance some more - ts.advanceTimeBy(2000, TimeUnit.MILLISECONDS); - // should have 3 from A and 2 from B - assertEquals(5, cachedResponses.getValues().length); - // dispose A, but leave B - conn.toInput.send(Frame.Cancel.from(1)); - // still same 5 frames - assertEquals(5, cachedResponses.getValues().length); - // advance again, should get 2 from B - ts.advanceTimeBy(2000, TimeUnit.MILLISECONDS); - assertEquals(7, cachedResponses.getValues().length); - - List frames = cachedResponses.take(7).toList().toBlocking().first(); - - // A frames (positions 0, 1, 3) incrementing 0, 1, 2 - assertEquals(1, frames.get(0).getStreamId()); - assertEquals("0_requestA", byteToString(frames.get(0).getData())); - assertEquals(1, frames.get(1).getStreamId()); - assertEquals("1_requestA", byteToString(frames.get(1).getData())); - assertEquals(1, frames.get(3).getStreamId()); - assertEquals("2_requestA", byteToString(frames.get(3).getData())); - - // B frames (positions 2, 4, 5, 6) incrementing 0, 1, 2, 3 - assertEquals(2, frames.get(2).getStreamId()); - assertEquals("0_requestB", byteToString(frames.get(2).getData())); - assertEquals(2, frames.get(4).getStreamId()); - assertEquals("1_requestB", byteToString(frames.get(4).getData())); - assertEquals(2, frames.get(5).getStreamId()); - assertEquals("2_requestB", byteToString(frames.get(5).getData())); - assertEquals(2, frames.get(6).getStreamId()); - assertEquals("3_requestB", byteToString(frames.get(6).getData())); - } - - /* **********************************************************************************************/ - - private ReplaySubject captureResponses(TestConnection conn) { - // capture all responses to client - ReplaySubject rs = ReplaySubject.create(); - conn.write.add(rs::onNext); - return rs; - } - - private TestConnection establishConnection() { - return new TestConnection(); - } - - private org.reactivestreams.Subscriber PROTOCOL_SUBSCRIBER = new org.reactivestreams.Subscriber() { - - @Override - public void onSubscribe(Subscription s) { - s.request(Long.MAX_VALUE); - } - - @Override - public void onNext(Void t) { - - } - - @Override - public void onError(Throwable t) { - t.printStackTrace(); - } - - @Override - public void onComplete() { - - } - - }; - - - private void sendSetupFrame(TestConnection conn) { - // setup - conn.toInput.send(Frame.Setup.from(0, 0, 0, "UTF-8", "UTF-8", utf8EncodedPayload("", ""))); - } -} diff --git a/src/test/java/io/reactivesocket/internal/UnicastSubjectTest.java b/src/test/java/io/reactivesocket/internal/UnicastSubjectTest.java deleted file mode 100644 index a2ddbd3a8..000000000 --- a/src/test/java/io/reactivesocket/internal/UnicastSubjectTest.java +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2015 Netflix, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.reactivesocket.internal; - -import org.junit.Test; - -import io.reactivesocket.Frame; -import io.reactivesocket.FrameType; -import io.reactivesocket.TestUtil; -import io.reactivesocket.internal.UnicastSubject; -import io.reactivex.subscribers.TestSubscriber; - -import static org.junit.Assert.assertTrue; - -public class UnicastSubjectTest { - - @Test - public void testSubscribeReceiveValue() { - Frame f = TestUtil.utf8EncodedResponseFrame(1, FrameType.NEXT_COMPLETE, "response"); - UnicastSubject us = UnicastSubject.create(); - TestSubscriber ts = new TestSubscriber<>(); - us.subscribe(ts); - us.onNext(f); - ts.assertValue(f); - ts.assertNotTerminated(); - } - - @Test(expected = NullPointerException.class) - public void testNullPointerSendingWithoutSubscriber() { - Frame f = TestUtil.utf8EncodedResponseFrame(1, FrameType.NEXT_COMPLETE, "response"); - UnicastSubject us = UnicastSubject.create(); - us.onNext(f); - } - - @Test - public void testIllegalStateIfMultiSubscribe() { - UnicastSubject us = UnicastSubject.create(); - TestSubscriber f1 = new TestSubscriber<>(); - us.subscribe(f1); - TestSubscriber f2 = new TestSubscriber<>(); - us.subscribe(f2); - - f1.assertNotTerminated(); - for (Throwable e : f2.errors()) { - assertTrue( IllegalStateException.class.isInstance(e) - || NullPointerException.class.isInstance(e)); - } - } - -}