8000 Fix 'initialRequest' guard might be incorrect in DohResolver and potential 'rollback' on 'lastRequest' among concurrent requests. by LinZong · Pull Request #345 · dnsjava/dnsjava · GitHub
[go: up one dir, main page]

Skip to content

Fix 'initialRequest' guard might be incorrect in DohResolver and potential 'rollback' on 'lastRequest' among concurrent requests. #345

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/main/java/org/xbill/DNS/DohResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ public final class DohResolver implements Resolver {
USE_HTTP_CLIENT = initSuccess;
}

// package-visible for testing
long getNanoTime() {
return System.nanoTime();
}

/**
* Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
*
Expand Down Expand Up @@ -315,7 +320,7 @@ public CompletionStage<Message> sendAsync(Message query, Executor executor) {
private CompletionStage<Message> sendAsync8(final Message query, Executor executor) {
byte[] queryBytes = prepareQuery(query).toWire();
String url = getUrl(queryBytes);
long startTime = System.nanoTime();
long startTime = getNanoTime();
return maxConcurrentRequests
.acquire(timeout)
.handleAsync(
Expand Down Expand Up @@ -363,7 +368,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory);
}

Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
conn.setConnectTimeout((int) remainingTimeout.toMillis());
conn.setReadTimeout((int) remainingTimeout.toMillis());
conn.setRequestMethod(usePost ? "POST" : "GET");
Expand All @@ -389,7 +394,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
int offset = 0;
while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) {
offset += r;
remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
if (remainingTimeout.isNegative()) {
throw new SocketTimeoutException();
}
Expand All @@ -403,7 +408,7 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
byte[] buffer = new byte[4096];
int r;
while ((r = is.read(buffer, 0, buffer.length)) > 0) {
remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
if (remainingTimeout.isNegative()) {
throw new SocketTimeoutException();
}
Expand Down Expand Up @@ -432,7 +437,7 @@ private void discardStream(InputStream es) throws IOException {
}

private CompletionStage<Message> sendAsync11(final Message query, Executor executor) {
long startTime = System.nanoTime();
long startTime = getNanoTime();
byte[] queryBytes = prepareQuery(query).toWire();
String url = getUrl(queryBytes);

Expand All @@ -454,7 +459,7 @@ private CompletionStage<Message> sendAsync11(final Message query, Executor execu
// check if this request needs to be done synchronously because of HttpClient's stupidity to
// not use the connection pool for HTTP/2 until one connection is successfully established,
// which could lead to hundreds of connections (and threads with the default executor)
Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
return initialRequestLock
.acquire(remainingTimeout)
.handle(
Expand All @@ -476,14 +481,13 @@ private CompletionStage<Message> sendAsync11WithInitialRequestPermit(
Object requestBuilder,
Permit initialRequestPermit) {
long lastRequestTime = lastRequest.get();
boolean isInitialRequest =
(lastRequestTime < System.nanoTime() - idleConnectionTimeout.toNanos());
boolean isInitialRequest = idleConnectionTimeout.toNanos() > getNanoTime() - lastRequestTime;
if (!isInitialRequest) {
initialRequestPermit.release();
}

// check if we already exceeded the query timeout while checking the initial connection
Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
if (remainingTimeout.isNegative()) {
if (isInitialRequest) {
initialRequestPermit.release();
Expand Down Expand Up @@ -525,7 +529,7 @@ private CompletionStage<Message> sendAsync11WithConcurrentRequestPermit(
boolean isInitialRequest,
Permit maxConcurrentRequestPermit) {
// check if the stream lock acquisition took too long
Duration remainingTimeout = timeout.minus(System.nanoTime() - startTime, ChronoUnit.NANOS);
Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
if (remainingTimeout.isNegative()) {
if (isInitialRequest) {
initialRequestPermit.release();
Expand Down
78 changes: 78 additions & 0 deletions src/test/java/org/xbill/DNS/DohResolverTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.spy;

import io.netty.handler.codec.http.HttpHeaderNames;
import io.vertx.core.Future;
Expand All @@ -20,13 +22,20 @@
import java.time.Duration;
import java.util.Base64;
import java.util.Collections;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledForJreRange;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.stubbing.Answer;

@ExtendWith(VertxExtension.class)
class DohResolverTest {
Expand Down Expand Up @@ -202,6 +211,75 @@ private Future<HttpServer> setupResolverWithServer(
.onSuccess(server -> resolver.setUriTemplate("http://localhost:" + server.actualPort()));
}

@EnabledForJreRange(
min = JRE.JAVA_9,
disabledReason = "Java 8 implementation doesn't have the initial request guard")
@Test
void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
Vertx vertx, VertxTestContext context) {
AtomicLong startNanos = new AtomicLong(System.nanoTime());
resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2)));
resolver.setTimeout(Duration.ofSeconds(1));
// Simulate a nanoTime value that is lower than the idle timeout
doAnswer((Answer<Long>) invocationOnMock -> System.nanoTime() - startNanos.get())
.when(resolver)
.getNanoTime();

// Just add a 100ms delay before responding to the 1st call
// to simulate a 'concurrent doh request' for the 2nd call,
// then let the fake dns server respond to the 2nd call ASAP.
allRequestsUseTimeout = false;

// idleConnectionTimeout = 2s, lastRequest = 0L
// Ensure idleConnectionTimeout < System.nanoTime() - lastRequest (3s)

// Timeline:
// |<-------- 100ms -------->|
// ↑ ↑
// 1st call sent response of 1st call
// |20ms|<------ 80ms ------>|<------ few millis ------->|
// ↑ wait until 1st call ↑ ↑
// 2nd call begin 2nd call sent response of 2nd call

AtomicBoolean firstCallCompleted = new AtomicBoolean(false);

setupResolverWithServer(Duration.ofMillis(100L), 200, 2, vertx, context)
.onSuccess(
server -> {
// First call
CompletionStage<Message> firstCall = resolver.sendAsync(qm);
// Ensure second call was made after first call and uses a different query
startNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(20));
CompletionStage<Message> secondCall = resolver.sendAsync(Message.newQuery(qr));

Future.fromCompletionStage(firstCall)
.onComplete(
context.succeeding(
result ->
context.verify(
() -> {
assertEquals(Rcode.NOERROR, re 80C6 sult.getHeader().getRcode());
assertEquals(0, result.getHeader().getID());
assertEquals(queryName, result.getQuestion().getName());
firstCallCompleted.set(true);
})));

Future.fromCompletionStage(secondCall)
.onComplete(
context.succeeding(
result ->
context.verify(
() -> {
assertTrue(firstCallCompleted.get());
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
assertEquals(0, result.getHeader().getID());
assertEquals(queryName, result.getQuestion().getName());
// Complete context after the 2nd call was completed.
context.completeNow();
})));
});
}

private Future<HttpServer> setupServer(
Message expectedDnsRequest,
Message dnsResponse,
Expand Down
Loading
0