httpClients =
- Collections.synchronizedMap(new WeakHashMap<>());
+public final class DohResolver extends DohResolverCommon {
private final SSLSocketFactory sslSocketFactory;
- private static Object defaultHttpRequestBuilder;
- private static Method publisherOfByteArrayMethod;
- private static Method requestBuilderTimeoutMethod;
- private static Method requestBuilderCopyMethod;
- private static Method requestBuilderUriMethod;
- private static Method requestBuilderBuildMethod;
- private static Method requestBuilderPostMethod;
-
- private static Method httpClientNewBuilderMethod;
- private static Method httpClientBuilderTimeoutMethod;
- private static Method httpClientBuilderExecutorMethod;
- private static Method httpClientBuilderBuildMethod;
- private static Method httpClientSendAsyncMethod;
-
- private static Method byteArrayBodyPublisherMethod;
- private static Method httpResponseBodyMethod;
- private static Method httpResponseStatusCodeMethod;
-
- private boolean usePost = false;
- private Duration timeout = Duration.ofSeconds(5);
- private String uriTemplate;
- private final Duration idleConnectionTimeout;
- private OPTRecord queryOPT = new OPTRecord(0, 0, 0);
- private TSIG tsig;
- private Executor defaultExecutor = ForkJoinPool.commonPool();
-
- /**
- * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections.
- *
- * rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2.
- */
- private final AsyncSemaphore maxConcurrentRequests;
-
- private final AtomicLong lastRequest = new AtomicLong(0);
- private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1);
-
- private static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
-
- static {
- boolean initSuccess = false;
- if (!System.getProperty("java.version").startsWith("1.")) {
- try {
- Class> httpClientBuilderClass = Class.forName("java.net.http.HttpClient$Builder");
- Class> httpClientClass = Class.forName("java.net.http.HttpClient");
- Class> httpVersionEnum = Class.forName("java.net.http.HttpClient$Version");
- Class> httpRequestBuilderClass = Class.forName("java.net.http.HttpRequest$Builder");
- Class> httpRequestClass = Class.forName("java.net.http.HttpRequest");
- Class> bodyPublishersClass = Class.forName("java.net.http.HttpRequest$BodyPublishers");
- Class> bodyPublisherClass = Class.forName("java.net.http.HttpRequest$BodyPublisher");
- Class> httpResponseClass = Class.forName("java.net.http.HttpResponse");
- Class> bodyHandlersClass = Class.forName("java.net.http.HttpResponse$BodyHandlers");
- Class> bodyHandlerClass = Class.forName("java.net.http.HttpResponse$BodyHandler");
-
- // HttpClient.Builder
- httpClientBuilderTimeoutMethod =
- httpClientBuilderClass.getDeclaredMethod("connectTimeout", Duration.class);
- httpClientBuilderExecutorMethod =
- httpClientBuilderClass.getDeclaredMethod("executor", Executor.class);
- httpClientBuilderBuildMethod = httpClientBuilderClass.getDeclaredMethod("build");
-
- // HttpClient
- httpClientNewBuilderMethod = httpClientClass.getDeclaredMethod("newBuilder");
- httpClientSendAsyncMethod =
- httpClientClass.getDeclaredMethod("sendAsync", httpRequestClass, bodyHandlerClass);
-
- // HttpRequestBuilder
- Method requestBuilderHeaderMethod =
- httpRequestBuilderClass.getDeclaredMethod("header", String.class, String.class);
- Method requestBuilderVersionMethod =
- httpRequestBuilderClass.getDeclaredMethod("version", httpVersionEnum);
- requestBuilderTimeoutMethod =
- httpRequestBuilderClass.getDeclaredMethod("timeout", Duration.class);
- requestBuilderUriMethod = httpRequestBuilderClass.getDeclaredMethod("uri", URI.class);
- requestBuilderCopyMethod = httpRequestBuilderClass.getDeclaredMethod("copy");
- requestBuilderBuildMethod = httpRequestBuilderClass.getDeclaredMethod("build");
- requestBuilderPostMethod =
- httpRequestBuilderClass.getDeclaredMethod("POST", bodyPublisherClass);
-
- // HttpRequest
- Method requestBuilderNewBuilderMethod = httpRequestClass.getDeclaredMethod("newBuilder");
-
- // BodyPublishers
- publisherOfByteArrayMethod =
- bodyPublishersClass.getDeclaredMethod("ofByteArray", byte[].class);
-
- // BodyPublisher
- byteArrayBodyPublisherMethod = bodyHandlersClass.getDeclaredMethod("ofByteArray");
-
- // HttpResponse
- httpResponseBodyMethod = httpResponseClass.getDeclaredMethod("body");
- httpResponseStatusCodeMethod = httpResponseClass.getDeclaredMethod("statusCode");
-
- // defaultHttpRequestBuilder = HttpRequest.newBuilder();
- // defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2);
- // defaultHttpRequestBuilder.header("Content-Type", "application/dns-message");
- // defaultHttpRequestBuilder.header("Accept", "application/dns-message");
- defaultHttpRequestBuilder = requestBuilderNewBuilderMethod.invoke(null);
- @SuppressWarnings({"unchecked", "rawtypes"})
- Enum> http2Version = Enum.valueOf((Class) httpVersionEnum, "HTTP_2");
- requestBuilderVersionMethod.invoke(defaultHttpRequestBuilder, http2Version);
- requestBuilderHeaderMethod.invoke(
- defaultHttpRequestBuilder, "Content-Type", APPLICATION_DNS_MESSAGE);
- requestBuilderHeaderMethod.invoke(
- defaultHttpRequestBuilder, "Accept", APPLICATION_DNS_MESSAGE);
- initSuccess = true;
- } catch (ClassNotFoundException
- | NoSuchMethodException
- | IllegalAccessException
- | InvocationTargetException e) {
- // fallback to Java 8
- log.warn("Java >= 11 detected, but HttpRequest not available");
- }
- }
-
- USE_HTTP_CLIENT = initSuccess;
- }
-
- // package-visible for testing
- long getNanoTime() {
- return System.nanoTime();
- }
-
/**
* Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
*
* @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
*/
public DohResolver(String uriTemplate) {
- this(uriTemplate, 100, Duration.ofMinutes(2));
+ this(uriTemplate, 100, Duration.ZERO);
}
/**
@@ -201,22 +68,9 @@ public DohResolver(String uriTemplate) {
*/
public DohResolver(
String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) {
- this.uriTemplate = uriTemplate;
- this.idleConnectionTimeout = idleConnectionTimeout;
- if (maxConcurrentRequests <= 0) {
- throw new IllegalArgumentException("maxConcurrentRequests must be > 0");
- }
- if (!USE_HTTP_CLIENT) {
- try {
- int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5"));
- if (maxConcurrentRequests > javaMaxConn) {
- maxConcurrentRequests = javaMaxConn;
- }
- } catch (NumberFormatException nfe) {
- // well, use what we got
- }
- }
- this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests);
+ super(uriTemplate, maxConcurrentRequests);
+
+ log.debug("Using Java 8 implementation");
try {
sslSocketFactory = SSLContext.getDefault().getSocketFactory();
} catch (NoSuchAlgorithmException e) {
@@ -224,45 +78,6 @@ public DohResolver(
}
}
- @SneakyThrows
- private Object getHttpClient(Executor executor) {
- return httpClients.computeIfAbsent(
- executor,
- key -> {
- try {
- // return HttpClient.newBuilder()
- // .connectTimeout(timeout).
- // .executor(executor)
- // .build();
- Object httpClientBuilder = httpClientNewBuilderMethod.invoke(null);
- httpClientBuilderTimeoutMethod.invoke(httpClientBuilder, timeout);
- httpClientBuilderExecutorMethod.invoke(httpClientBuilder, key);
- return httpClientBuilderBuildMethod.invoke(httpClientBuilder);
- } catch (IllegalAccessException | InvocationTargetException e) {
- log.warn("Could not create a HttpClient with for Executor {}", key, e);
- return null;
- }
- });
- }
-
- /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */
- @Override
- public void setPort(int port) {
- // Not implemented, port is part of the URI
- }
-
- /** Not implemented. */
- @Override
- public void setTCP(boolean flag) {
- // Not implemented, HTTP is always TCP
- }
-
- /** Not implemented. */
- @Override
- public void setIgnoreTruncation(boolean flag) {
- // Not implemented, protocol uses TCP and doesn't have truncation
- }
-
/**
* Sets the EDNS information on outgoing messages.
*
@@ -272,87 +87,85 @@ public void setIgnoreTruncation(boolean flag) {
* @param options EDNS options to be set in the OPT record
*/
@Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
public void setEDNS(int version, int payloadSize, int flags, List options) {
- switch (version) {
- case -1:
- queryOPT = null;
- break;
-
- case 0:
- queryOPT = new OPTRecord(0, 0, version, flags, options);
- break;
-
- default:
- throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable");
- }
- }
-
- @Override
- public void setTSIGKey(TSIG key) {
- this.tsig = key;
- }
-
- @Override
- public void setTimeout(Duration timeout) {
- this.timeout = timeout;
- httpClients.clear();
- }
-
- @Override
- public Duration getTimeout() {
- return timeout;
+ // required for source- and binary compatibility
+ super.setEDNS(version, payloadSize, flags, options);
}
@Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
public CompletionStage sendAsync(Message query) {
- return sendAsync(query, defaultExecutor);
+ // required for source- and binary compatibility
+ return this.sendAsync(query, defaultExecutor);
}
@Override
public CompletionStage sendAsync(Message query, Executor executor) {
- if (USE_HTTP_CLIENT) {
- return sendAsync11(query, executor);
- }
-
- return sendAsync8(query, executor);
- }
-
- private CompletionStage sendAsync8(final Message query, Executor executor) {
byte[] queryBytes = prepareQuery(query).toWire();
String url = getUrl(queryBytes);
long startTime = getNanoTime();
- return maxConcurrentRequests
- .acquire(timeout)
- .handleAsync(
- (permit, ex) -> {
- if (ex != null) {
- return this.timeoutFailedFuture(query, ex);
- } else {
- try {
- SendAndGetMessageBytesResponse result =
- sendAndGetMessageBytes(url, queryBytes, startTime);
- Message response;
- if (result.rc == Rcode.NOERROR) {
- response = new Message(result.responseBytes);
- verifyTSIG(query, response, result.responseBytes, tsig);
+ int queryId = query.getHeader().getID();
+
+ CompletableFuture f =
+ maxConcurrentRequests
+ .acquire(timeout, queryId, executor)
+ .handleAsync(
+ (permit, ex) -> {
+ if (ex != null) {
+ return this.timeoutFailedFuture(
+ query, "could not acquire lock to send request", ex);
} else {
- response = new Message(0);
- response.getHeader().setRcode(result.rc);
+ try {
+ SendAndGetMessageBytesResponse result =
+ sendAndGetMessageBytes(url, queryBytes, startTime);
+ Message response;
+ if (result.rc == Rcode.NOERROR) {
+ response = new Message(result.responseBytes);
+ verifyTSIG(query, response, result.responseBytes, tsig);
+ } else {
+ response = new Message(0);
+ response.getHeader().setRcode(result.rc);
+ }
+
+ response.setResolver(this);
+ return CompletableFuture.completedFuture(response);
+ } catch (SocketTimeoutException e) {
+ return this.timeoutFailedFuture(query, e);
+ } catch (IOException | URISyntaxException e) {
+ return this.failedFuture(e);
+ } finally {
+ permit.release(queryId, executor);
+ }
}
+ },
+ executor)
+ .thenCompose(Function.identity())
+ .toCompletableFuture();
- response.setResolver(this);
- return CompletableFuture.completedFuture(response);
- } catch (SocketTimeoutException e) {
- return this.timeoutFailedFuture(query, e);
- } catch (IOException | URISyntaxException e) {
- return this.failedFuture(e);
- } finally {
- permit.release();
- }
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ return TimeoutCompletableFuture.compatTimeout(
+ f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS)
+ .exceptionally(
+ ex -> {
+ if (ex instanceof TimeoutException) {
+ throw new CompletionException(
+ new TimeoutException(
+ "Query "
+ + queryId
+ + " for "
+ + query.getQuestion().getName()
+ + "/"
+ + Type.string(query.getQuestion().getType())
+ + " timed out in remaining "
+ + remainingTimeout.toMillis()
+ + "ms"));
+ } else if (ex instanceof CompletionException) {
+ throw (CompletionException) ex;
}
- },
- executor)
- .thenCompose(Function.identity());
+
+ throw new CompletionException(ex);
+ });
}
@Value
@@ -367,15 +180,28 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
if (conn instanceof HttpsURLConnection) {
((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory);
}
-
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- conn.setConnectTimeout((int) remainingTimeout.toMillis());
- conn.setReadTimeout((int) remainingTimeout.toMillis());
conn.setRequestMethod(usePost ? "POST" : "GET");
conn.setRequestProperty("Content-Type", APPLICATION_DNS_MESSAGE);
conn.setRequestProperty("Accept", APPLICATION_DNS_MESSAGE);
+
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ throw new SocketTimeoutException("No time left to connect");
+ }
+
+ conn.setConnectTimeout((int) remainingTimeout.toMillis());
if (usePost) {
conn.setDoOutput(true);
+ }
+
+ conn.connect();
+ remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ throw new SocketTimeoutException("No time left to request data");
+ }
+
+ conn.setReadTimeout((int) remainingTimeout.toMillis());
+ if (usePost) {
conn.getOutputStream().write(queryBytes);
}
@@ -395,13 +221,23 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) {
offset += r;
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- throw new SocketTimeoutException();
+
+ // Don't throw if we just received all data
+ if (offset != responseBytes.length
+ && (remainingTimeout.isNegative() || remainingTimeout.isZero())) {
+ throw new SocketTimeoutException(
+ "Timed out waiting for response data, got "
+ + offset
+ + " of "
+ + responseBytes.length
+ + " expected bytes");
}
}
+
if (offset < responseBytes.length) {
throw new EOFException("Could not read expected content length");
}
+
return new SendAndGetMessageBytesResponse(Rcode.NOERROR, responseBytes);
} else {
try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) {
@@ -409,8 +245,9 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes(
int r;
while ((r = is.read(buffer, 0, buffer.length)) > 0) {
remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- throw new SocketTimeoutException();
+ if (remainingTimeout.isNegative() || remainingTimeout.isZero()) {
+ throw new SocketTimeoutException(
+ "Timed out waiting for response data, got " + bos.size() + " bytes so far");
}
bos.write(buffer, 0, r);
}
@@ -436,275 +273,10 @@ private void discardStream(InputStream es) throws IOException {
}
}
- private CompletionStage sendAsync11(final Message query, Executor executor) {
- long startTime = getNanoTime();
- byte[] queryBytes = prepareQuery(query).toWire();
- String url = getUrl(queryBytes);
-
- // var requestBuilder = defaultHttpRequestBuilder.copy();
- // requestBuilder.uri(URI.create(url));
- Object requestBuilder;
- try {
- requestBuilder = requestBuilderCopyMethod.invoke(defaultHttpRequestBuilder);
- requestBuilderUriMethod.invoke(requestBuilder, URI.create(url));
- if (usePost) {
- // requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes));
- requestBuilderPostMethod.invoke(
- requestBuilder, publisherOfByteArrayMethod.invoke(null, queryBytes));
- }
- } catch (IllegalAccessException | InvocationTargetException e) {
- return failedFuture(e);
- }
-
- // check if this request needs to be done synchronously because of HttpClient's stupidity to
- // not use the connection pool for HTTP/2 until one connection is successfully established,
- // which could lead to hundreds of connections (and threads with the default executor)
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- return initialRequestLock
- .acquire(remainingTimeout)
- .handle(
- (initialRequestPermit, initialRequestEx) -> {
- if (initialRequestEx != null) {
- return this.timeoutFailedFuture(query, initialRequestEx);
- } else {
- return sendAsync11WithInitialRequestPermit(
- query, executor, startTime, requestBuilder, initialRequestPermit);
- }
- })
- .thenCompose(Function.identity());
- }
-
- private CompletionStage sendAsync11WithInitialRequestPermit(
- Message query,
- Executor executor,
- long startTime,
- Object requestBuilder,
- Permit initialRequestPermit) {
- long lastRequestTime = lastRequest.get();
- boolean isInitialRequest = idleConnectionTimeout.toNanos() > getNanoTime() - lastRequestTime;
- if (!isInitialRequest) {
- initialRequestPermit.release();
- }
-
- // check if we already exceeded the query timeout while checking the initial connection
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- return timeoutFailedFuture(query, null);
- }
-
- // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the
- // request, but fail with an IOException which also CLOSES the connection... *facepalm*
- return maxConcurrentRequests
- .acquire(remainingTimeout)
- .handle(
- (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> {
- if (maxConcurrentRequestEx != null) {
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- return this.timeoutFailedFuture(query, maxConcurrentRequestEx);
- } else {
- return sendAsync11WithConcurrentRequestPermit(
- query,
- executor,
- startTime,
- requestBuilder,
- initialRequestPermit,
- isInitialRequest,
- maxConcurrentRequestPermit);
- }
- })
- .thenCompose(Function.identity());
- }
-
- private CompletionStage sendAsync11WithConcurrentRequestPermit(
- Message query,
- Executor executor,
- long startTime,
- Object requestBuilder,
- Permit initialRequestPermit,
- boolean isInitialRequest,
- Permit maxConcurrentRequestPermit) {
- // check if the stream lock acquisition took too long
- Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
- if (remainingTimeout.isNegative()) {
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- maxConcurrentRequestPermit.release();
- return timeoutFailedFuture(query, null);
- }
-
- // var httpRequest = requestBuilder.timeout(remainingTimeout).build();
- // var bodyHandler = HttpResponse.BodyHandlers.ofByteArray();
- // return getHttpClient(executor).sendAsync(httpRequest, bodyHandler)
- try {
- Object httpClient = getHttpClient(executor);
- requestBuilderTimeoutMethod.invoke(requestBuilder, remainingTimeout);
- Object httpRequest = requestBuilderBuildMethod.invoke(requestBuilder);
- Object bodyHandler = byteArrayBodyPublisherMethod.invoke(null);
- CompletableFuture f =
- ((CompletableFuture>)
- httpClientSendAsyncMethod.invoke(httpClient, httpRequest, bodyHandler))
- .whenComplete(
- (result, ex) -> {
- if (ex == null) {
- lastRequest.set(startTime);
- }
- maxConcurrentRequestPermit.release();
- if (isInitialRequest) {
- initialRequestPermit.release();
- }
- })
- .handleAsync(
- (response, ex) -> {
- if (ex != null) {
- if (ex.getCause().getClass().getSimpleName().equals("HttpTimeoutException")) {
- return this.timeoutFailedFuture(query, ex.getCause());
- } else {
- return this.failedFuture(ex);
- }
- } else {
- try {
- Message responseMessage;
- // int rc = response.statusCode();
- int rc = (int) httpResponseStatusCodeMethod.invoke(response);
- if (rc >= 200 && rc < 300) {
- // byte[] responseBytes = response.body();
- byte[] responseBytes = (byte[]) httpResponseBodyMethod.invoke(response);
- responseMessage = new Message(responseBytes);
- verifyTSIG(query, responseMessage, responseBytes, tsig);
- } else {
- responseMessage = new Message();
- responseMessage.getHeader().setRcode(Rcode.SERVFAIL);
- }
-
- responseMessage.setResolver(this);
- return CompletableFuture.completedFuture(responseMessage);
- } catch (IOException | IllegalAccessException | InvocationTargetException e) {
- return this.failedFuture(e);
- }
- }
- },
- executor)
- .thenCompose(Function.identity());
- return TimeoutCompletableFuture.compatTimeout(
- f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS);
- } catch (IllegalAccessException | InvocationTargetException e) {
- return failedFuture(e);
- }
- }
-
- private CompletableFuture failedFuture(Throwable e) {
+ @Override
+ protected CompletableFuture failedFuture(Throwable e) {
CompletableFuture f = new CompletableFuture<>();
f.completeExceptionally(e);
return f;
}
-
- private CompletableFuture timeoutFailedFuture(Message query, Throwable inner) {
- return failedFuture(
- new IOException(
- "Query "
- + query.getHeader().getID()
- + " for "
- + query.getQuestion().getName()
- + "/"
- + Type.string(query.getQuestion().getType())
- + " timed out",
- inner));
- }
-
- private String getUrl(byte[] queryBytes) {
- String url = uriTemplate;
- if (!usePost) {
- url += "?dns=" + base64.toString(queryBytes, true);
- }
- return url;
- }
-
- private Message prepareQuery(Message query) {
- Message preparedQuery = query.clone();
- preparedQuery.getHeader().setID(0);
- if (queryOPT != null && preparedQuery.getOPT() == null) {
- preparedQuery.addRecord(queryOPT, Section.ADDITIONAL);
- }
-
- if (tsig != null) {
- tsig.apply(preparedQuery, null);
- }
-
- return preparedQuery;
- }
-
- private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) {
- if (tsig == null) {
- return;
- }
-
- int error = tsig.verify(response, b, query.getGeneratedTSIG());
- log.debug(
- "TSIG verify for query {}, {}/{}: {}",
- query.getHeader().getID(),
- query.getQuestion().getName(),
- Type.string(query.getQuestion().getType()),
- Rcode.TSIGstring(error));
- }
-
- /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */
- public boolean isUsePost() {
- return usePost;
- }
-
- /**
- * Sets the HTTP method to use for resolving.
- *
- * @param usePost {@code true} to use POST, {@code false} to use GET (the default).
- */
- public void setUsePost(boolean usePost) {
- this.usePost = usePost;
- }
-
- /** Gets the current URI used for resolving. */
- public String getUriTemplate() {
- return uriTemplate;
- }
-
- /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */
- public void setUriTemplate(String uriTemplate) {
- this.uriTemplate = uriTemplate;
- }
-
- /**
- * Gets the default {@link Executor} for request handling, defaults to {@link
- * ForkJoinPool#commonPool()}.
- *
- * @since 3.3
- * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used.
- */
- @Deprecated
- public Executor getExecutor() {
- return defaultExecutor;
- }
-
- /**
- * Sets the default {@link Executor} for request handling.
- *
- * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link
- * ForkJoinPool#commonPool()}).
- * @since 3.3
- * @deprecated Use {@link #sendAsync(Message, Executor)}.
- */
- @Deprecated
- public void setExecutor(Executor executor) {
- this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor;
- httpClients.clear();
- }
-
- @Override
- public String toString() {
- return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}";
- }
}
diff --git a/src/main/java/org/xbill/DNS/DohResolverCommon.java b/src/main/java/org/xbill/DNS/DohResolverCommon.java
new file mode 100644
index 00000000..0fb8ac07
--- /dev/null
+++ b/src/main/java/org/xbill/DNS/DohResolverCommon.java
@@ -0,0 +1,232 @@
+// SPDX-License-Identifier: BSD-3-Clause
+package org.xbill.DNS;
+
+import java.time.Duration;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ForkJoinPool;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicLong;
+import lombok.extern.slf4j.Slf4j;
+import org.xbill.DNS.utils.base64;
+
+@Slf4j
+abstract class DohResolverCommon implements Resolver {
+ /**
+ * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections.
+ *
+ * rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2.
+ */
+ protected final AsyncSemaphore maxConcurrentRequests;
+
+ protected final AtomicLong lastRequest = new AtomicLong(0);
+
+ protected static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
+
+ protected boolean usePost = false;
+ protected Duration timeout = Duration.ofSeconds(5);
+ protected String uriTemplate;
+ protected OPTRecord queryOPT = new OPTRecord(0, 0, 0);
+ protected TSIG tsig;
+ protected Executor defaultExecutor = ForkJoinPool.commonPool();
+
+ // package-visible for testing
+ long getNanoTime() {
+ return System.nanoTime();
+ }
+
+ /**
+ * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
+ *
+ * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
+ * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1
+ * connections for Java 8. On Java 8 this cannot exceed the system property {@code
+ * http.maxConnections}.
+ */
+ protected DohResolverCommon(String uriTemplate, int maxConcurrentRequests) {
+ this.uriTemplate = uriTemplate;
+ if (maxConcurrentRequests <= 0) {
+ throw new IllegalArgumentException("maxConcurrentRequests must be > 0");
+ }
+
+ try {
+ int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5"));
+ if (maxConcurrentRequests > javaMaxConn) {
+ maxConcurrentRequests = javaMaxConn;
+ }
+ } catch (NumberFormatException nfe) {
+ // well, use what we got
+ }
+
+ this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests, "concurrent");
+ }
+
+ /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */
+ @Override
+ public void setPort(int port) {
+ // Not implemented, port is part of the URI
+ }
+
+ /** Not implemented. */
+ @Override
+ public void setTCP(boolean flag) {
+ // Not implemented, HTTP is always TCP
+ }
+
+ /** Not implemented. */
+ @Override
+ public void setIgnoreTruncation(boolean flag) {
+ // Not implemented, protocol uses TCP and doesn't have truncation
+ }
+
+ /**
+ * Sets the EDNS information on outgoing messages.
+ *
+ * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS.
+ * @param payloadSize ignored
+ * @param flags EDNS extended flags to be set in the OPT record.
+ * @param options EDNS options to be set in the OPT record
+ */
+ @Override
+ public void setEDNS(int version, int payloadSize, int flags, List options) {
+ switch (version) {
+ case -1:
+ queryOPT = null;
+ break;
+
+ case 0:
+ queryOPT = new OPTRecord(0, 0, version, flags, options);
+ break;
+
+ default:
+ throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable");
+ }
+ }
+
+ @Override
+ public void setTSIGKey(TSIG key) {
+ this.tsig = key;
+ }
+
+ @Override
+ public void setTimeout(Duration timeout) {
+ this.timeout = timeout;
+ }
+
+ @Override
+ public Duration getTimeout() {
+ return timeout;
+ }
+
+ protected String getUrl(byte[] queryBytes) {
+ String url = uriTemplate;
+ if (!usePost) {
+ url += "?dns=" + base64.toString(queryBytes, true);
+ }
+ return url;
+ }
+
+ protected Message prepareQuery(Message query) {
+ Message preparedQuery = query.clone();
+ preparedQuery.getHeader().setID(0);
+ if (queryOPT != null && preparedQuery.getOPT() == null) {
+ preparedQuery.addRecord(queryOPT, Section.ADDITIONAL);
+ }
+
+ if (tsig != null) {
+ tsig.apply(preparedQuery, null);
+ }
+
+ return preparedQuery;
+ }
+
+ protected void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) {
+ if (tsig == null) {
+ return;
+ }
+
+ int error = tsig.verify(response, b, query.getGeneratedTSIG());
+ log.debug(
+ "TSIG verify for query {}, {}/{}: {}",
+ query.getHeader().getID(),
+ query.getQuestion().getName(),
+ Type.string(query.getQuestion().getType()),
+ Rcode.TSIGstring(error));
+ }
+
+ /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */
+ public boolean isUsePost() {
+ return usePost;
+ }
+
+ /**
+ * Sets the HTTP method to use for resolving.
+ *
+ * @param usePost {@code true} to use POST, {@code false} to use GET (the default).
+ */
+ public void setUsePost(boolean usePost) {
+ this.usePost = usePost;
+ }
+
+ /** Gets the current URI used for resolving. */
+ public String getUriTemplate() {
+ return uriTemplate;
+ }
+
+ /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */
+ public void setUriTemplate(String uriTemplate) {
+ this.uriTemplate = uriTemplate;
+ }
+
+ /**
+ * Gets the default {@link Executor} for request handling, defaults to {@link
+ * ForkJoinPool#commonPool()}.
+ *
+ * @since 3.3
+ * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used.
+ */
+ @Deprecated
+ public Executor getExecutor() {
+ return defaultExecutor;
+ }
+
+ /**
+ * Sets the default {@link Executor} for request handling.
+ *
+ * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link
+ * ForkJoinPool#commonPool()}).
+ * @since 3.3
+ * @deprecated Use {@link #sendAsync(Message, Executor)}.
+ */
+ @Deprecated
+ public void setExecutor(Executor executor) {
+ this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor;
+ }
+
+ @Override
+ public String toString() {
+ return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}";
+ }
+
+ protected abstract CompletableFuture failedFuture(Throwable e);
+
+ protected final CompletableFuture timeoutFailedFuture(Message query, Throwable inner) {
+ return timeoutFailedFuture(query, null, inner);
+ }
+
+ protected final CompletableFuture timeoutFailedFuture(
+ Message query, String message, Throwable inner) {
+ return failedFuture(
+ new TimeoutException(
+ "Query "
+ + query.getHeader().getID()
+ + " for "
+ + query.getQuestion().getName()
+ + "/"
+ + Type.string(query.getQuestion().getType())
+ + " timed out"
+ + (message != null ? ": " + message : "")
+ + (inner != null && inner.getMessage() != null ? ", " + inner.getMessage() : "")));
+ }
+}
diff --git a/src/main/java/org/xbill/DNS/Lookup.java b/src/main/java/org/xbill/DNS/Lookup.java
index 135be61e..875d03b0 100644
--- a/src/main/java/org/xbill/DNS/Lookup.java
+++ b/src/main/java/org/xbill/DNS/Lookup.java
@@ -8,6 +8,7 @@
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -43,7 +44,7 @@
public final class Lookup {
private static Resolver defaultResolver;
- private static List defaultSearchPath;
+ private static List defaultSearchPath = Collections.emptyList();
private static Map defaultCaches;
private static int defaultNdots;
private static HostsFileParser defaultHostsFileParser;
@@ -173,6 +174,11 @@ public static synchronized List getDefaultSearchPath() {
* made absolute.
*/
public static synchronized void setDefaultSearchPath(List domains) {
+ if (domains == null) {
+ defaultSearchPath = Collections.emptyList();
+ return;
+ }
+
defaultSearchPath = convertSearchPathDomainList(domains);
}
@@ -184,6 +190,11 @@ public static synchronized void setDefaultSearchPath(List domains) {
* made absolute.
*/
public static synchronized void setDefaultSearchPath(Name... domains) {
+ if (domains == null) {
+ defaultSearchPath = Collections.emptyList();
+ return;
+ }
+
setDefaultSearchPath(Arrays.asList(domains));
}
@@ -196,16 +207,16 @@ public static synchronized void setDefaultSearchPath(Name... domains) {
public static synchronized void setDefaultSearchPath(String... domains)
throws TextParseException {
if (domains == null) {
- defaultSearchPath = null;
+ defaultSearchPath = Collections.emptyList();
return;
}
- List newdomains = new ArrayList<>(domains.length);
+ List newDomains = new ArrayList<>(domains.length);
for (String domain : domains) {
- newdomains.add(Name.fromString(domain, Name.root));
+ newDomains.add(Name.fromString(domain, Name.root));
}
- defaultSearchPath = newdomains;
+ defaultSearchPath = newDomains;
}
/**
@@ -395,6 +406,11 @@ public void setResolver(Resolver resolver) {
* made absolute.
*/
public void setSearchPath(List domains) {
+ if (domains == null) {
+ this.searchPath = Collections.emptyList();
+ return;
+ }
+
this.searchPath = convertSearchPathDomainList(domains);
}
@@ -406,6 +422,11 @@ public void setSearchPath(List domains) {
* made absolute.
*/
public void setSearchPath(Name... domains) {
+ if (domains == null) {
+ this.searchPath = Collections.emptyList();
+ return;
+ }
+
setSearchPath(Arrays.asList(domains));
}
@@ -417,15 +438,16 @@ public void setSearchPath(Name... domains) {
*/
public void setSearchPath(String... domains) throws TextParseException {
if (domains == null) {
- this.searchPath = null;
+ this.searchPath = Collections.emptyList();
return;
}
- List newdomains = new ArrayList<>(domains.length);
+ List newDomains = new ArrayList<>(domains.length);
for (String domain : domains) {
- newdomains.add(Name.fromString(domain, Name.root));
+ newDomains.add(Name.fromString(domain, Name.root));
}
- this.searchPath = newdomains;
+
+ this.searchPath = newDomains;
}
/**
@@ -671,8 +693,6 @@ public Record[] run() {
}
if (name.isAbsolute()) {
resolve(name, null);
- } else if (searchPath == null) {
- resolve(name, Name.root);
} else {
if (name.labels() > ndots) {
resolve(name, Name.root);
diff --git a/src/main/java/org/xbill/DNS/Message.java b/src/main/java/org/xbill/DNS/Message.java
index bb214c8d..767450b8 100644
--- a/src/main/java/org/xbill/DNS/Message.java
+++ b/src/main/java/org/xbill/DNS/Message.java
@@ -810,9 +810,11 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord)
List additionalSectionSets = getSectionRRsets(Section.ADDITIONAL);
List authoritySectionSets = getSectionRRsets(Section.AUTHORITY);
- List cleanedAnswerSection = new ArrayList<>();
- List cleanedAuthoritySection = new ArrayList<>();
- List cleanedAdditionalSection = new ArrayList<>();
+ @SuppressWarnings("unchecked")
+ List[] cleanedSection = new ArrayList[4];
+ cleanedSection[Section.ANSWER] = new ArrayList<>();
+ cleanedSection[Section.AUTHORITY] = new ArrayList<>();
+ cleanedSection[Section.ADDITIONAL] = new ArrayList<>();
boolean hadNsInAuthority = false;
// For the ANSWER section, remove all "irrelevant" records and add synthesized CNAMEs from
@@ -843,7 +845,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord)
// If DNAME was queried, don't attempt to synthesize CNAME
if (query.getQuestion().getType() != Type.DNAME) {
// The DNAME is valid, accept it
- cleanedAnswerSection.add(rrset);
+ cleanedSection[Section.ANSWER].add(rrset);
// Check if the next rrset is correct CNAME, otherwise synthesize a CNAME
RRset nextRRSet = answerSectionSets.size() >= i + 2 ? answerSectionSets.get(i + 1) : null;
@@ -863,7 +865,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord)
// Add a synthesized CNAME; TTL=0 to avoid caching
Name dnameTarget = sname.fromDNAME(dname);
- cleanedAnswerSection.add(
+ cleanedSection[Section.ANSWER].add(
new RRset(new CNAMERecord(sname, dname.getDClass(), 0, dnameTarget)));
sname = dnameTarget;
@@ -872,7 +874,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord)
for (i++; i < answerSectionSets.size(); i++) {
rrset = answerSectionSets.get(i);
if (rrset.getName().equals(oldSname)) {
- cleanedAnswerSection.add(rrset);
+ cleanedSection[Section.ANSWER].add(rrset);
} else {
break;
}
@@ -943,14 +945,14 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord)
}
sname = ((CNAMERecord) rrset.first()).getTarget();
- cleanedAnswerSection.add(rrset);
+ cleanedSection[Section.ANSWER].add(rrset);
// In CNAME ANY response, can have data after CNAME
if (query.getQuestion().getType() == Type.ANY) {
for (i++; i < answerSectionSets.size(); i++) {
rrset = answerSectionSets.get(i);
if (rrset.getName().equals(oldSname)) {
- cleanedAnswerSection.add(rrset);
+ cleanedSection[Section.ANSWER].add(rrset);
} else {
break;
}
@@ -973,9 +975,9 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord)
}
// Mark the additional names from relevant RRset as OK
- cleanedAnswerSection.add(rrset);
+ cleanedSection[Section.ANSWER].add(rrset);
if (sname.equals(rrset.getName())) {
- addAdditionalRRset(rrset, additionalSectionSets, cleanedAdditionalSection);
+ addAdditionalRRset(rrset, additionalSectionSets, cleanedSection[Section.ADDITIONAL]);
}
}
@@ -1045,15 +1047,25 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord)
}
}
- cleanedAuthoritySection.add(rrset);
- addAdditionalRRset(rrset, additionalSectionSets, cleanedAdditionalSection);
+ cleanedSection[Section.AUTHORITY].add(rrset);
+ addAdditionalRRset(rrset, additionalSectionSets, cleanedSection[Section.ADDITIONAL]);
}
Message cleanedMessage = new Message(this.getHeader());
cleanedMessage.sections[Section.QUESTION] = this.sections[Section.QUESTION];
- cleanedMessage.sections[Section.ANSWER] = rrsetListToRecords(cleanedAnswerSection);
- cleanedMessage.sections[Section.AUTHORITY] = rrsetListToRecords(cleanedAuthoritySection);
- cleanedMessage.sections[Section.ADDITIONAL] = rrsetListToRecords(cleanedAdditionalSection);
+ for (int section : new int[] {Section.ANSWER, Section.AUTHORITY, Section.ADDITIONAL}) {
+ cleanedMessage.sections[section] = rrsetListToRecords(cleanedSection[section]);
+
+ // Fixup counts in the header
+ cleanedMessage
+ .getHeader()
+ .setCount(
+ section,
+ cleanedMessage.sections[section] == null
+ ? 0
+ : cleanedMessage.sections[section].size());
+ }
+
return cleanedMessage;
}
diff --git a/src/main/java/org/xbill/DNS/NioClient.java b/src/main/java/org/xbill/DNS/NioClient.java
index d9718e00..14211a29 100644
--- a/src/main/java/org/xbill/DNS/NioClient.java
+++ b/src/main/java/org/xbill/DNS/NioClient.java
@@ -10,6 +10,7 @@
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.util.Iterator;
+import java.util.function.Consumer;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@@ -24,9 +25,9 @@
* The following configuration parameter is available:
*
*
- * - dnsjava.nio.selector_timeout
+ *
- {@value SELECTOR_TIMEOUT_PROPERTY}
*
- Set selector timeout in milliseconds. Default/Max 1000, Min 1.
- *
- dnsjava.nio.register_shutdown_hook
+ *
- {@value REGISTER_SHUTDOWN_HOOK_PROPERTY}
*
- Register Shutdown Hook termination of NIO. Default True.
*
*
@@ -35,16 +36,24 @@
@Slf4j
@NoArgsConstructor(access = AccessLevel.NONE)
public abstract class NioClient {
+ static final String SELECTOR_TIMEOUT_PROPERTY = "dnsjava.nio.selector_timeout";
+ static final String REGISTER_SHUTDOWN_HOOK_PROPERTY = "dnsjava.nio.register_shutdown_hook";
+ private static final Object NIO_CLIENT_LOCK = new Object();
+
/** Packet logger, if available. */
private static PacketLogger packetLogger = null;
private static final Runnable[] TIMEOUT_TASKS = new Runnable[2];
- private static final Runnable[] REGISTRATIONS_TASKS = new Runnable[2];
+
+ @SuppressWarnings("unchecked")
+ private static final Consumer[] REGISTRATIONS_TASKS = new Consumer[2];
+
private static final Runnable[] CLOSE_TASKS = new Runnable[2];
private static Thread selectorThread;
private static Thread closeThread;
private static volatile Selector selector;
private static volatile boolean run;
+ private static volatile boolean closeDone;
interface KeyProcessor {
void processReadyKey(SelectionKey key);
@@ -52,7 +61,7 @@ interface KeyProcessor {
static Selector selector() throws IOException {
if (selector == null) {
- synchronized (NioClient.class) {
+ synchronized (NIO_CLIENT_LOCK) {
if (selector == null) {
selector = Selector.open();
log.debug("Starting dnsjava NIO selector thread");
@@ -63,8 +72,7 @@ static Selector selector() throws IOException {
selectorThread.start();
closeThread = new Thread(() -> close(true));
closeThread.setName("dnsjava NIO shutdown hook");
- if (Boolean.parseBoolean(
- System.getProperty("dnsjava.nio.register_shutdown_hook", "true"))) {
+ if (Boolean.parseBoolean(System.getProperty(REGISTER_SHUTDOWN_HOOK_PROPERTY, "true"))) {
Runtime.getRuntime().addShutdownHook(closeThread);
}
}
@@ -74,13 +82,23 @@ static Selector selector() throws IOException {
return selector;
}
- /** Shutdown the network I/O used by the {@link SimpleResolver}. */
+ /**
+ * Shutdown the network I/O used by the {@link SimpleResolver}.
+ *
+ * @implNote Does not wait until the selector thread has stopped. But users may immediately start
+ * using the {@link NioClient} again.
+ * @since 3.4.0
+ */
public static void close() {
close(false);
}
private static void close(boolean fromHook) {
run = false;
+ Selector localSelector = selector;
+ if (localSelector != null) {
+ selector.wakeup();
+ }
if (!fromHook) {
try {
@@ -90,40 +108,26 @@ private static void close(boolean fromHook) {
}
}
- try {
- runTasks(CLOSE_TASKS);
- } catch (Exception e) {
- log.warn("Failed to execute shutdown task, ignoring and continuing close", e);
- }
-
- Selector localSelector = selector;
- Thread localSelectorThread = selectorThread;
- synchronized (NioClient.class) {
- selector = null;
- selectorThread = null;
- closeThread = null;
+ if (localSelector == null) {
+ // Prevent hanging when close() was called without starting
+ return;
}
- if (localSelector != null) {
- localSelector.wakeup();
+ synchronized (NIO_CLIENT_LOCK) {
try {
- localSelector.close();
- } catch (IOException e) {
- log.warn("Failed to properly close selector, ignoring and continuing close", e);
- }
- }
-
- if (localSelectorThread != null) {
- try {
- localSelectorThread.join();
+ while (!closeDone) {
+ NIO_CLIENT_LOCK.wait();
+ }
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
+ } finally {
+ closeDone = false;
}
}
}
static void runSelector() {
- int timeout = Integer.getInteger("dnsjava.nio.selector_timeout", 1000);
+ int timeout = Integer.getInteger(SELECTOR_TIMEOUT_PROPERTY, 1000);
if (timeout <= 0 || timeout > 1000) {
throw new IllegalArgumentException("Invalid selector_timeout, must be between 1 and 1000");
@@ -136,7 +140,7 @@ static void runSelector() {
}
if (run) {
- runTasks(REGISTRATIONS_TASKS);
+ runRegistrationTasks();
processReadyKeys();
}
} catch (IOException e) {
@@ -145,30 +149,74 @@ static void runSelector() {
// ignore
}
}
+
+ runClose();
log.debug("dnsjava NIO selector thread stopped");
}
- static synchronized void setTimeoutTask(Runnable r, boolean isTcpClient) {
+ private static void runClose() {
+ try {
+ runTasks(CLOSE_TASKS);
+ } catch (Exception e) {
+ log.warn("Failed to execute shutdown task, ignoring and continuing close", e);
+ }
+
+ Selector localSelector = selector;
+ Thread localSelectorThread = selectorThread;
+ synchronized (NIO_CLIENT_LOCK) {
+ selector = null;
+ selectorThread = null;
+ closeThread = null;
+ closeDone = true;
+ NIO_CLIENT_LOCK.notifyAll();
+ }
+
+ if (localSelector != null) {
+ try {
+ localSelector.close();
+ } catch (IOException e) {
+ log.warn("Failed to properly close selector, ignoring and continuing close", e);
+ }
+ }
+
+ if (localSelectorThread != null) {
+ try {
+ localSelectorThread.join();
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ }
+ }
+ }
+
+ static void setTimeoutTask(Runnable r, boolean isTcpClient) {
addTask(TIMEOUT_TASKS, r, isTcpClient);
}
- static synchronized void setRegistrationsTask(Runnable r, boolean isTcpClient) {
- addTask(REGISTRATIONS_TASKS, r, isTcpClient);
+ static void setRegistrationsTask(Consumer r, boolean isTcpClient) {
+ addRegistrationTask(r, isTcpClient);
}
- static synchronized void setCloseTask(Runnable r, boolean isTcpClient) {
+ static void setCloseTask(Runnable r, boolean isTcpClient) {
addTask(CLOSE_TASKS, r, isTcpClient);
}
- private static void addTask(Runnable[] closeTasks, Runnable r, boolean isTcpClient) {
+ private static void addTask(Runnable[] tasks, Runnable r, boolean isTcpClient) {
if (isTcpClient) {
- closeTasks[0] = r;
+ tasks[0] = r;
} else {
- closeTasks[1] = r;
+ tasks[1] = r;
}
}
- private static synchronized void runTasks(Runnable[] runnables) {
+ private static void addRegistrationTask(Consumer r, boolean isTcpClient) {
+ if (isTcpClient) {
+ REGISTRATIONS_TASKS[0] = r;
+ } else {
+ REGISTRATIONS_TASKS[1] = r;
+ }
+ }
+
+ private static void runTasks(Runnable[] runnables) {
Runnable r0 = runnables[0];
if (r0 != null) {
r0.run();
@@ -179,6 +227,17 @@ private static synchronized void runTasks(Runnable[] runnables) {
}
}
+ private static void runRegistrationTasks() {
+ Consumer r0 = REGISTRATIONS_TASKS[0];
+ if (r0 != null) {
+ r0.accept(selector);
+ }
+ Consumer r1 = REGISTRATIONS_TASKS[1];
+ if (r1 != null) {
+ r1.accept(selector);
+ }
+ }
+
private static void processReadyKeys() {
Iterator it = selector.selectedKeys().iterator();
while (it.hasNext()) {
diff --git a/src/main/java/org/xbill/DNS/NioTcpClient.java b/src/main/java/org/xbill/DNS/NioTcpClient.java
index 566c00eb..200a5331 100644
--- a/src/main/java/org/xbill/DNS/NioTcpClient.java
+++ b/src/main/java/org/xbill/DNS/NioTcpClient.java
@@ -32,7 +32,7 @@ final class NioTcpClient extends NioClient implements TcpIoClient {
setCloseTask(this::closeTcp, true);
}
- private void processPendingRegistrations() {
+ private void processPendingRegistrations(Selector selector) {
while (!registrationQueue.isEmpty()) {
ChannelState state = registrationQueue.poll();
if (state == null) {
@@ -40,7 +40,6 @@ private void processPendingRegistrations() {
}
try {
- final Selector selector = selector();
if (!state.channel.isConnected()) {
state.channel.register(selector, SelectionKey.OP_CONNECT, state);
} else {
diff --git a/src/main/java/org/xbill/DNS/NioUdpClient.java b/src/main/java/org/xbill/DNS/NioUdpClient.java
index ce6607f2..e2ccfef2 100644
--- a/src/main/java/org/xbill/DNS/NioUdpClient.java
+++ b/src/main/java/org/xbill/DNS/NioUdpClient.java
@@ -54,7 +54,7 @@ final class NioUdpClient extends NioClient implements UdpIoClient {
setCloseTask(this::closeUdp, false);
}
- private void processPendingRegistrations() {
+ private void processPendingRegistrations(Selector selector) {
while (!registrationQueue.isEmpty()) {
Transaction t = registrationQueue.poll();
if (t == null) {
@@ -63,7 +63,7 @@ private void processPendingRegistrations() {
try {
log.trace("Registering OP_READ for transaction with id {}", t.id);
- t.channel.register(selector(), SelectionKey.OP_READ, t);
+ t.channel.register(selector, SelectionKey.OP_READ, t);
t.send();
} catch (IOException e) {
t.completeExceptionally(e);
diff --git a/src/main/java/org/xbill/DNS/Resolver.java b/src/main/java/org/xbill/DNS/Resolver.java
index 5d3c0ff3..5b6ce511 100644
--- a/src/main/java/org/xbill/DNS/Resolver.java
+++ b/src/main/java/org/xbill/DNS/Resolver.java
@@ -9,6 +9,7 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
@@ -239,11 +240,16 @@ default Object sendAsync(Message query, ResolverListener listener) {
(result, throwable) -> {
if (throwable != null) {
Exception exception;
+ if (throwable instanceof CompletionException && throwable.getCause() != null) {
+ throwable = throwable.getCause();
+ }
+
if (throwable instanceof Exception) {
exception = (Exception) throwable;
} else {
exception = new Exception(throwable);
}
+
listener.handleException(id, exception);
return null;
}
diff --git a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java
index 24796df6..e2ae60f6 100644
--- a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java
+++ b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java
@@ -1,8 +1,6 @@
// SPDX-License-Identifier: BSD-3-Clause
package org.xbill.DNS;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.ScheduledThreadPoolExecutor;
@@ -10,57 +8,28 @@
import java.util.concurrent.TimeoutException;
import lombok.extern.slf4j.Slf4j;
-/**
- * Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. On Java 9+
- * the built-in method is called.
- */
+/** Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. */
@Slf4j
class TimeoutCompletableFuture extends CompletableFuture {
- private static final Method orTimeoutMethod;
-
- static {
- Method localOrTimeoutMethod;
- if (!System.getProperty("java.version").startsWith("1.")) {
- try {
- localOrTimeoutMethod =
- CompletableFuture.class.getMethod("orTimeout", long.class, TimeUnit.class);
- } catch (NoSuchMethodException e) {
- localOrTimeoutMethod = null;
- log.warn(
- "CompletableFuture.orTimeout method not found in Java 9+, using custom implementation",
- e);
- }
- } else {
- localOrTimeoutMethod = null;
- }
- orTimeoutMethod = localOrTimeoutMethod;
- }
-
public CompletableFuture compatTimeout(long timeout, TimeUnit unit) {
return compatTimeout(this, timeout, unit);
}
- @SuppressWarnings("unchecked")
public static CompletableFuture compatTimeout(
CompletableFuture f, long timeout, TimeUnit unit) {
- if (orTimeoutMethod == null) {
- return orTimeout(f, timeout, unit);
- } else {
- try {
- return (CompletableFuture) orTimeoutMethod.invoke(f, timeout, unit);
- } catch (IllegalAccessException | InvocationTargetException e) {
- return orTimeout(f, timeout, unit);
- }
+ if (timeout <= 0) {
+ f.completeExceptionally(new TimeoutException("timeout is " + timeout + ", but must be > 0"));
}
- }
- private static CompletableFuture orTimeout(
- CompletableFuture f, long timeout, TimeUnit unit) {
ScheduledFuture> sf =
TimeoutScheduler.executor.schedule(
() -> {
if (!f.isDone()) {
- f.completeExceptionally(new TimeoutException());
+ f.completeExceptionally(
+ new TimeoutException(
+ "Timeout of "
+ + unit.toMillis(timeout)
+ + "ms has elapsed before the task completed"));
}
},
timeout,
diff --git a/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java b/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java
index c617bc6d..cd5dfb88 100644
--- a/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java
+++ b/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java
@@ -8,12 +8,15 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
+import java.time.Clock;
+import java.time.Duration;
import java.time.Instant;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicInteger;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.xbill.DNS.Address;
@@ -31,12 +34,23 @@
public final class HostsFileParser {
private final int maxFullCacheFileSizeBytes =
Integer.parseInt(System.getProperty("dnsjava.hostsfile.max_size_bytes", "16384"));
+ private final Duration fileChangeCheckInterval =
+ Duration.ofMillis(
+ Integer.parseInt(
+ System.getProperty("dnsjava.hostsfile.change_check_interval_ms", "300000")));
- private final Map hostsCache = new HashMap<>();
private final Path path;
private final boolean clearCacheOnChange;
+ private Clock clock = Clock.systemUTC();
+
+ @SuppressWarnings("java:S3077")
+ private volatile Map hostsCache;
+
+ private Instant lastFileModificationCheckTime = Instant.MIN;
private Instant lastFileReadTime = Instant.MIN;
private boolean isEntireFileParsed;
+ private boolean hostsFileWarningLogged = false;
+ private long hostsFileSizeBytes;
/**
* Creates a new instance based on the current OS's default. Unix and alike (or rather everything
@@ -86,8 +100,7 @@ public HostsFileParser(Path path, boolean clearCacheOnChange) {
* @throws IllegalArgumentException when {@code type} is not {@link org.xbill.DNS.Type#A} or{@link
* org.xbill.DNS.Type#AAAA}.
*/
- public synchronized Optional getAddressForHost(Name name, int type)
- throws IOException {
+ public Optional getAddressForHost(Name name, int type) throws IOException {
Objects.requireNonNull(name, "name is required");
if (type != Type.A && type != Type.AAAA) {
throw new IllegalArgumentException("type can only be A or AAAA");
@@ -100,13 +113,11 @@ public synchronized Optional getAddressForHost(Name name, int type)
return Optional.of(cachedAddress);
}
- if (isEntireFileParsed || !Files.exists(path)) {
+ if (isEntireFileParsed) {
return Optional.empty();
}
- if (Files.size(path) <= maxFullCacheFileSizeBytes) {
- parseEntireHostsFile();
- } else {
+ if (hostsFileSizeBytes > maxFullCacheFileSizeBytes) {
searchHostsFileForEntry(name, type);
}
@@ -116,9 +127,11 @@ public synchronized Optional getAddressForHost(Name name, int type)
private void parseEntireHostsFile() throws IOException {
String line;
int lineNumber = 0;
+ AtomicInteger addressFailures = new AtomicInteger(0);
+ AtomicInteger nameFailures = new AtomicInteger(0);
try (BufferedReader hostsReader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) {
while ((line = hostsReader.readLine()) != null) {
- LineData lineData = parseLine(++lineNumber, line);
+ LineData lineData = parseLine(++lineNumber, line, addressFailures, nameFailures);
if (lineData != null) {
for (Name lineName : lineData.names) {
InetAddress lineAddress =
@@ -129,15 +142,24 @@ private void parseEntireHostsFile() throws IOException {
}
}
- isEntireFileParsed = true;
+ if (!hostsFileWarningLogged && (addressFailures.get() > 0 || nameFailures.get() > 0)) {
+ log.warn(
+ "Failed to parse entire hosts file {}, address failures={}, name failures={}",
+ path,
+ addressFailures.get(),
+ nameFailures);
+ hostsFileWarningLogged = true;
+ }
}
private void searchHostsFileForEntry(Name name, int type) throws IOException {
String line;
int lineNumber = 0;
+ AtomicInteger addressFailures = new AtomicInteger(0);
+ AtomicInteger nameFailures = new AtomicInteger(0);
try (BufferedReader hostsReader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) {
while ((line = hostsReader.readLine()) != null) {
- LineData lineData = parseLine(++lineNumber, line);
+ LineData lineData = parseLine(++lineNumber, line, addressFailures, nameFailures);
if (lineData != null) {
for (Name lineName : lineData.names) {
boolean isSearchedEntry = lineName.equals(name);
@@ -151,6 +173,16 @@ private void searchHostsFileForEntry(Name name, int type) throws IOException {
}
}
}
+
+ if (!hostsFileWarningLogged && (addressFailures.get() > 0 || nameFailures.get() > 0)) {
+ log.warn(
+ "Failed to find {} in hosts file {}, address failures={}, name failures={}",
+ name,
+ path,
+ addressFailures.get(),
+ nameFailures);
+ hostsFileWarningLogged = true;
+ }
}
@RequiredArgsConstructor
@@ -160,7 +192,8 @@ private static final class LineData {
final Iterable extends Name> names;
}
- private LineData parseLine(int lineNumber, String line) {
+ private LineData parseLine(
+ int lineNumber, String line, AtomicInteger addressFailures, AtomicInteger nameFailures) {
String[] lineTokens = getLineTokens(line);
if (lineTokens.length < 2) {
return null;
@@ -174,24 +207,26 @@ private LineData parseLine(int lineNumber, String line) {
}
if (lineAddressBytes == null) {
- log.warn("Could not decode address {}, {}#L{}", lineTokens[0], path, lineNumber);
+ log.debug("Could not decode address {}, {}#L{}", lineTokens[0], path, lineNumber);
+ addressFailures.incrementAndGet();
return null;
}
Iterable extends Name> lineNames =
Arrays.stream(lineTokens)
.skip(1)
- .map(lineTokenName -> safeName(lineTokenName, lineNumber))
+ .map(lineTokenName -> safeName(lineTokenName, lineNumber, nameFailures))
.filter(Objects::nonNull)
::iterator;
return new LineData(lineAddressType, lineAddressBytes, lineNames);
}
- private Name safeName(String name, int lineNumber) {
+ private Name safeName(String name, int lineNumber, AtomicInteger nameFailures) {
try {
return Name.fromString(name, Name.root);
} catch (TextParseException e) {
- log.warn("Could not decode name {}, {}#L{}, skipping", name, path, lineNumber);
+ log.debug("Could not decode name {}, {}#L{}, skipping", name, path, lineNumber);
+ nameFailures.incrementAndGet();
return null;
}
}
@@ -207,21 +242,61 @@ private String[] getLineTokens(String line) {
}
private void validateCache() throws IOException {
- if (clearCacheOnChange) {
+ if (!clearCacheOnChange) {
+ if (hostsCache == null) {
+ synchronized (this) {
+ if (hostsCache == null) {
+ readHostsFile();
+ }
+ }
+ }
+
+ return;
+ }
+
+ if (lastFileModificationCheckTime.plus(fileChangeCheckInterval).isBefore(clock.instant())) {
+ log.debug("Checked for changes more than 5minutes ago, checking");
// A filewatcher / inotify etc. would be nicer, but doesn't work. c.f. the write up at
// https://blog.arkey.fr/2019/09/13/watchservice-and-bind-mount/
- Instant fileTime =
- Files.exists(path) ? Files.getLastModifiedTime(path).toInstant() : Instant.MAX;
- if (fileTime.isAfter(lastFileReadTime)) {
- // skip logging noise when the cache is empty anyway
- if (!hostsCache.isEmpty()) {
- log.info("Local hosts database has changed at {}, clearing cache", fileTime);
- hostsCache.clear();
+
+ synchronized (this) {
+ if (!lastFileModificationCheckTime
+ .plus(fileChangeCheckInterval)
+ .isBefore(clock.instant())) {
+ log.debug("Never mind, check fulfilled in another thread");
+ return;
+ }
+
+ lastFileModificationCheckTime = clock.instant();
+ readHostsFile();
+ }
+ }
+ }
+
+ private void readHostsFile() throws IOException {
+ if (Files.exists(path)) {
+ Instant fileTime = Files.getLastModifiedTime(path).toInstant();
+ if (!lastFileReadTime.equals(fileTime)) {
+ createOrClearCache();
+
+ hostsFileSizeBytes = Files.size(path);
+ if (hostsFileSizeBytes <= maxFullCacheFileSizeBytes) {
+ parseEntireHostsFile();
+ isEntireFileParsed = true;
}
- isEntireFileParsed = false;
lastFileReadTime = fileTime;
}
+ } else {
+ createOrClearCache();
+ }
+ }
+
+ private void createOrClearCache() {
+ if (hostsCache == null) {
+ hostsCache = new ConcurrentHashMap<>();
+ } else {
+ hostsCache.clear();
}
}
@@ -231,6 +306,10 @@ private String key(Name name, int type) {
// for unit testing only
int cacheSize() {
- return hostsCache.size();
+ return hostsCache == null ? 0 : hostsCache.size();
+ }
+
+ void setClock(Clock clock) {
+ this.clock = clock;
}
}
diff --git a/src/main/java11/org/xbill/DNS/AsyncSemaphore.java b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java
new file mode 100644
index 00000000..d2dcf6dd
--- /dev/null
+++ b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java
@@ -0,0 +1,66 @@
+// SPDX-License-Identifier: BSD-3-Clause
+package org.xbill.DNS;
+
+import java.time.Duration;
+import java.util.ArrayDeque;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import lombok.extern.slf4j.Slf4j;
+
+@Slf4j
+final class AsyncSemaphore {
+ private final Queue> queue = new ArrayDeque<>();
+ private final Permit singletonPermit = new Permit();
+ private final String name;
+ private volatile int permits;
+
+ final class Permit {
+ public void release(int id, Executor executor) {
+ synchronized (queue) {
+ CompletableFuture next = queue.poll();
+ if (next == null) {
+ permits++;
+ log.trace("{} permit released id={}, available={}", name, id, permits);
+ } else {
+ log.trace("{} permit released id={}, available={}, immediate next", name, id, permits);
+ next.completeAsync(() -> this, executor);
+ }
+ }
+ }
+ }
+
+ AsyncSemaphore(int permits, String name) {
+ this.permits = permits;
+ this.name = name;
+ log.debug("Using Java 11+ implementation for {}", name);
+ }
+
+ CompletionStage acquire(Duration timeout, int id, Executor executor) {
+ synchronized (queue) {
+ if (permits > 0) {
+ permits--;
+ log.trace("{} permit acquired id={}, available={}", name, id, permits);
+ return CompletableFuture.completedFuture(singletonPermit);
+ } else {
+ CompletableFuture f = new CompletableFuture<>();
+ f.orTimeout(timeout.toNanos(), TimeUnit.NANOSECONDS)
+ .whenCompleteAsync(
+ (result, ex) -> {
+ synchronized (queue) {
+ if (ex != null) {
+ log.trace("{} permit timed out id={}, available={}", name, id, permits);
+ }
+ queue.remove(f);
+ }
+ },
+ executor);
+ log.trace("{} permit queued id={}, available={}", name, id, permits);
+ queue.add(f);
+ return f;
+ }
+ }
+ }
+}
diff --git a/src/main/java11/org/xbill/DNS/DohResolver.java b/src/main/java11/org/xbill/DNS/DohResolver.java
new file mode 100644
index 00000000..8e73eca5
--- /dev/null
+++ b/src/main/java11/org/xbill/DNS/DohResolver.java
@@ -0,0 +1,311 @@
+// SPDX-License-Identifier: BSD-3-Clause
+package org.xbill.DNS;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.net.http.HttpTimeoutException;
+import java.time.Duration;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.WeakHashMap;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.CompletionStage;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.function.Function;
+import lombok.SneakyThrows;
+import lombok.extern.slf4j.Slf4j;
+import org.xbill.DNS.AsyncSemaphore.Permit;
+
+/**
+ * Proof-of-concept DNS over HTTP (DoH)
+ * resolver. This class is not suitable for high load scenarios because of the shortcomings of
+ * Java's built-in HTTP clients. For more control, implement your own {@link Resolver} using e.g. OkHttp.
+ *
+ * On Java 8, it uses HTTP/1.1, which is against the recommendation of RFC 8484 to use HTTP/2 and
+ * thus slower. On Java 11 or newer, HTTP/2 is always used, but the built-in HttpClient has its own
+ * issues with connection handling.
+ *
+ *
As of 2020-09-13, the following limits of public resolvers for HTTP/2 were observed:
+ *
https://cloudflare-dns.com/dns-query: max streams=250, idle timeout=400s
+ * https://dns.google/dns-query: max streams=100, idle timeout=240s
+ *
+ * @since 3.0
+ */
+@Slf4j
+public final class DohResolver extends DohResolverCommon {
+ private static final String APPLICATION_DNS_MESSAGE = "application/dns-message";
+ private static final Map httpClients =
+ Collections.synchronizedMap(new WeakHashMap<>());
+ private static final HttpRequest.Builder defaultHttpRequestBuilder;
+
+ private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1, "initial request");
+
+ private final Duration idleConnectionTimeout;
+
+ static {
+ defaultHttpRequestBuilder = HttpRequest.newBuilder();
+ defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2);
+ defaultHttpRequestBuilder.header("Content-Type", APPLICATION_DNS_MESSAGE);
+ defaultHttpRequestBuilder.header("Accept", APPLICATION_DNS_MESSAGE);
+ }
+
+ /**
+ * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
+ *
+ * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
+ */
+ public DohResolver(String uriTemplate) {
+ this(uriTemplate, 100, Duration.ofMinutes(2));
+ }
+
+ /**
+ * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s).
+ *
+ * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query}
+ * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1
+ * connections for Java 8. On Java 8 this cannot exceed the system property {@code
+ * http.maxConnections}.
+ * @param idleConnectionTimeout Max. idle time for HTTP/2 connections until a request is
+ * serialized. Applies to Java 11+ only.
+ * @since 3.3
+ */
+ public DohResolver(
+ String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) {
+ super(uriTemplate, maxConcurrentRequests);
+ log.debug("Using Java 11+ implementation");
+ this.idleConnectionTimeout = idleConnectionTimeout;
+ }
+
+ @SneakyThrows
+ private HttpClient getHttpClient(Executor executor) {
+ return httpClients.computeIfAbsent(
+ executor,
+ key -> {
+ try {
+ return HttpClient.newBuilder().connectTimeout(timeout).executor(executor).build();
+ } catch (IllegalArgumentException e) {
+ log.warn("Could not create a HttpClient for Executor {}", key, e);
+ return null;
+ }
+ });
+ }
+
+ @Override
+ public void setTimeout(Duration timeout) {
+ this.timeout = timeout;
+ httpClients.clear();
+ }
+
+ /**
+ * Sets the EDNS information on outgoing messages.
+ *
+ * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS.
+ * @param payloadSize ignored
+ * @param flags EDNS extended flags to be set in the OPT record.
+ * @param options EDNS options to be set in the OPT record
+ */
+ @Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
+ public void setEDNS(int version, int payloadSize, int flags, List options) {
+ // required for source- and binary compatibility
+ super.setEDNS(version, payloadSize, flags, options);
+ }
+
+ @Override
+ @SuppressWarnings("java:S1185") // required for source- and binary compatibility
+ public CompletionStage sendAsync(Message query) {
+ return this.sendAsync(query, defaultExecutor);
+ }
+
+ @Override
+ public CompletionStage sendAsync(Message query, Executor executor) {
+ long startTime = getNanoTime();
+ byte[] queryBytes = prepareQuery(query).toWire();
+ String url = getUrl(queryBytes);
+
+ var requestBuilder = defaultHttpRequestBuilder.copy();
+ requestBuilder.uri(URI.create(url));
+ if (usePost) {
+ requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes));
+ }
+
+ // check if this request needs to be done synchronously because of HttpClient's stupidity to
+ // not use the connection pool for HTTP/2 until one connection is successfully established,
+ // which could lead to hundreds of connections (and threads with the default executor)
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ return timeoutFailedFuture(query, "no time left to acquire lock for first request", null);
+ }
+
+ return initialRequestLock
+ .acquire(remainingTimeout, query.getHeader().getID(), executor)
+ .handle(
+ (initialRequestPermit, initialRequestEx) -> {
+ if (initialRequestEx != null) {
+ return this.timeoutFailedFuture(query, initialRequestEx);
+ } else {
+ return sendAsyncWithInitialRequestPermit(
+ query, executor, startTime, requestBuilder, initialRequestPermit);
+ }
+ })
+ .thenCompose(Function.identity());
+ }
+
+ private CompletionStage sendAsyncWithInitialRequestPermit(
+ Message query,
+ Executor executor,
+ long startTime,
+ HttpRequest.Builder requestBuilder,
+ Permit initialRequestPermit) {
+ int queryId = query.getHeader().getID();
+ long lastRequestTime = lastRequest.get();
+ long requestDeltaNanos = getNanoTime() - lastRequestTime;
+ boolean isInitialRequest =
+ lastRequestTime == 0 || idleConnectionTimeout.toNanos() < requestDeltaNanos;
+ if (!isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+
+ // check if we already exceeded the query timeout while checking the initial connection
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+
+ return timeoutFailedFuture(
+ query, "no time left to acquire lock for concurrent request", null);
+ }
+
+ // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the
+ // request, but fail with an IOException which also CLOSES the connection... *facepalm*
+ return maxConcurrentRequests
+ .acquire(remainingTimeout, queryId, executor)
+ .handle(
+ (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> {
+ if (maxConcurrentRequestEx != null) {
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+ return this.timeoutFailedFuture(
+ query,
+ "timed out waiting for a concurrent request lease",
+ maxConcurrentRequestEx);
+ } else {
+ return sendAsyncWithConcurrentRequestPermit(
+ query,
+ executor,
+ startTime,
+ requestBuilder,
+ initialRequestPermit,
+ isInitialRequest,
+ maxConcurrentRequestPermit);
+ }
+ })
+ .thenCompose(Function.identity());
+ }
+
+ private CompletionStage sendAsyncWithConcurrentRequestPermit(
+ Message query,
+ Executor executor,
+ long startTime,
+ HttpRequest.Builder requestBuilder,
+ Permit initialRequestPermit,
+ boolean isInitialRequest,
+ Permit maxConcurrentRequestPermit) {
+ int queryId = query.getHeader().getID();
+
+ // check if the stream lock acquisition took too long
+ Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS);
+ if (remainingTimeout.toMillis() <= 0) {
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+
+ maxConcurrentRequestPermit.release(queryId, executor);
+ return timeoutFailedFuture(
+ query, "no time left to acquire lock for concurrent request", null);
+ }
+
+ var httpRequest = requestBuilder.timeout(remainingTimeout).build();
+ var bodyHandler = HttpResponse.BodyHandlers.ofByteArray();
+ return getHttpClient(executor)
+ .sendAsync(httpRequest, bodyHandler)
+ .whenComplete(
+ (result, ex) -> {
+ if (ex == null) {
+ lastRequest.set(startTime);
+ }
+ maxConcurrentRequestPermit.release(queryId, executor);
+ if (isInitialRequest) {
+ initialRequestPermit.release(queryId, executor);
+ }
+ })
+ .handleAsync(
+ (response, ex) -> {
+ if (ex != null) {
+ if (ex instanceof HttpTimeoutException) {
+ return this.timeoutFailedFuture(
+ query, "http request did not complete", ex.getCause());
+ } else {
+ return CompletableFuture.failedFuture(ex);
+ }
+ } else {
+ try {
+ Message responseMessage;
+ int rc = response.statusCode();
+ if (rc >= 200 && rc < 300) {
+ byte[] responseBytes = response.body();
+ responseMessage = new Message(responseBytes);
+ verifyTSIG(query, responseMessage, responseBytes, tsig);
+ } else {
+ responseMessage = new Message();
+ responseMessage.getHeader().setRcode(Rcode.SERVFAIL);
+ }
+
+ responseMessage.setResolver(this);
+ return CompletableFuture.completedFuture(responseMessage);
+ } catch (IOException e) {
+ return CompletableFuture.failedFuture(e);
+ }
+ }
+ },
+ executor)
+ .thenCompose(Function.identity())
+ .orTimeout(remainingTimeout.toMillis(), TimeUnit.MILLISECONDS)
+ .exceptionally(
+ ex -> {
+ if (ex instanceof TimeoutException) {
+ throw new CompletionException(
+ new TimeoutException(
+ "Query "
+ + query.getHeader().getID()
+ + " for "
+ + query.getQuestion().getName()
+ + "/"
+ + Type.string(query.getQuestion().getType())
+ + " timed out in remaining "
+ + remainingTimeout.toMillis()
+ + "ms"));
+ } else if (ex instanceof CompletionException) {
+ throw (CompletionException) ex;
+ }
+
+ throw new CompletionException(ex);
+ });
+ }
+
+ @Override
+ protected CompletableFuture failedFuture(Throwable e) {
+ return CompletableFuture.failedFuture(e);
+ }
+}
diff --git a/src/test/java/org/xbill/DNS/DohResolverTest.java b/src/test/java/org/xbill/DNS/DohResolverTest.java
index 26daf19c..6d1500d6 100644
--- a/src/test/java/org/xbill/DNS/DohResolverTest.java
+++ b/src/test/java/org/xbill/DNS/DohResolverTest.java
@@ -28,8 +28,8 @@
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
+import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledForJreRange;
import org.junit.jupiter.api.condition.JRE;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -38,8 +38,8 @@
import org.mockito.stubbing.Answer;
@ExtendWith(VertxExtension.class)
+@Slf4j
class DohResolverTest {
- private DohResolver resolver;
private final Name queryName = Name.fromConstantString("example.com.");
private final Record qr = Record.newRecord(queryName, Type.A, DClass.IN);
private final Message qm = Message.newQuery(qr);
@@ -48,7 +48,6 @@ class DohResolverTest {
@BeforeEach
void beforeEach() throws UnknownHostException {
- resolver = new DohResolver("http://localhost");
Record ar =
new ARecord(
Name.fromConstantString("example.com."),
@@ -59,11 +58,16 @@ void beforeEach() throws UnknownHostException {
a.addRecord(ar, Section.ANSWER);
}
+ private DohResolver getResolver() {
+ return new DohResolver("http://localhost");
+ }
+
@ParameterizedTest
@ValueSource(booleans = {false, true})
void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = getResolver();
resolver.setUsePost(usePost);
- setupResolverWithServer(Duration.ZERO, 200, 1, vertx, context)
+ setupResolverWithServer(resolver, Duration.ZERO, 200, 1, vertx, context)
.onSuccess(
server ->
Future.fromCompletionStage(resolver.sendAsync(qm))
@@ -79,10 +83,13 @@ void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
}))));
}
- @Test
- void timeoutResolve(Vertx vertx, VertxTestContext context) {
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void timeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = getResolver();
resolver.setTimeout(Duration.ofSeconds(1));
- setupResolverWithServer(Duration.ofSeconds(5), 200, 1, vertx, context)
+ resolver.setUsePost(usePost);
+ setupResolverWithServer(resolver, Duration.ofSeconds(5), 200, 1, vertx, context)
.onSuccess(
server ->
Future.fromCompletionStage(resolver.sendAsync(qm))
@@ -98,9 +105,12 @@ void timeoutResolve(Vertx vertx, VertxTestContext context) {
}))));
}
- @Test
- void servfailResolve(Vertx vertx, VertxTestContext context) {
- setupResolverWithServer(Duration.ZERO, 301, 1, vertx, context)
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void servfailResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = getResolver();
+ resolver.setUsePost(usePost);
+ setupResolverWithServer(resolver, Duration.ZERO, 301, 1, vertx, context)
.onSuccess(
server ->
Future.fromCompletionStage(resolver.sendAsync(qm))
@@ -114,12 +124,14 @@ void servfailResolve(Vertx vertx, VertxTestContext context) {
}))));
}
- @Test
- void limitRequestsResolve(Vertx vertx, VertxTestContext context) {
- resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2));
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void limitRequestsResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2));
+ resolver.setUsePost(usePost);
int requests = 100;
Checkpoint cpPass = context.checkpoint(requests);
- setupResolverWithServer(Duration.ofMillis(100), 200, 5, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofMillis(100), 200, 5, vertx, context)
.onSuccess(
server -> {
for (int i = 0; i < requests; i++) {
@@ -137,13 +149,15 @@ void limitRequestsResolve(Vertx vertx, VertxTestContext context) {
});
}
- @Test
- void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) {
- resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void initialRequestSlowResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ resolver.setUsePost(usePost);
int requests = 20;
allRequestsUseTimeout = false;
Checkpoint cpPass = context.checkpoint(requests);
- setupResolverWithServer(Duration.ofSeconds(1), 200, 2, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofSeconds(1), 200, 2, vertx, context)
.onSuccess(
server -> {
for (int i = 0; i < requests; i++) {
@@ -161,19 +175,23 @@ void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) {
});
}
- @Test
- void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
- resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
+ void initialRequestTimeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) {
+ DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2));
+ resolver.setUsePost(usePost);
resolver.setTimeout(Duration.ofSeconds(1));
int requests = 20;
allRequestsUseTimeout = false;
Checkpoint cpPass = context.checkpoint(requests - 1);
Checkpoint cpFail = context.checkpoint();
- setupResolverWithServer(Duration.ofSeconds(2), 200, 2, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofSeconds(2), 200, 2, vertx, context)
.onSuccess(
server -> {
+ Message q = qm.clone();
+ q.getHeader().setID(0);
resolver
- .sendAsync(qm)
+ .sendAsync(q)
.whenComplete(
(result, ex) -> {
if (ex == null) {
@@ -185,9 +203,11 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
vertx.setTimer(
1000,
timer -> {
- for (int i = 0; i < requests - 1; i++) {
+ for (int i = 1; i < requests; i++) {
+ Message qq = qm.clone();
+ qq.getHeader().setID(i);
resolver
- .sendAsync(qm)
+ .sendAsync(qq)
.whenComplete(
(result, ex) -> {
if (ex == null) {
@@ -202,6 +222,7 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) {
}
private Future setupResolverWithServer(
+ DohResolver resolver,
Duration responseDelay,
int statusCode,
int maxConcurrentRequests,
@@ -214,12 +235,14 @@ private Future setupResolverWithServer(
@EnabledForJreRange(
min = JRE.JAVA_9,
disabledReason = "Java 8 implementation doesn't have the initial request guard")
- @Test
+ @ParameterizedTest
+ @ValueSource(booleans = {false, true})
void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
- Vertx vertx, VertxTestContext context) {
+ boolean usePost, Vertx vertx, VertxTestContext context) {
AtomicLong startNanos = new AtomicLong(System.nanoTime());
- resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2)));
+ DohResolver resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2)));
resolver.setTimeout(Duration.ofSeconds(1));
+ resolver.setUsePost(usePost);
// Simulate a nanoTime value that is lower than the idle timeout
doAnswer((Answer) invocationOnMock -> System.nanoTime() - startNanos.get())
.when(resolver)
@@ -243,11 +266,13 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
AtomicBoolean firstCallCompleted = new AtomicBoolean(false);
- setupResolverWithServer(Duration.ofMillis(100L), 200, 2, vertx, context)
+ setupResolverWithServer(resolver, Duration.ofMillis(100L), 200, 2, vertx, context)
.onSuccess(
server -> {
// First call
- CompletionStage firstCall = resolver.sendAsync(qm);
+ CompletionStage firstCall =
+ resolver.sendAsync(qm).whenComplete((msg, ex) -> firstCallCompleted.set(true));
+
// Ensure second call was made after first call and uses a different query
startNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(20));
CompletionStage secondCall = resolver.sendAsync(Message.newQuery(qr));
@@ -261,7 +286,6 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime(
assertEquals(Rcode.NOERROR, result.getHeader().getRcode());
assertEquals(0, result.getHeader().getID());
assertEquals(queryName, result.getQuestion().getName());
- firstCallCompleted.set(true);
})));
Future.fromCompletionStage(secondCall)
@@ -302,10 +326,12 @@ private Future setupServer(
int thisRequestNum = requestCount.incrementAndGet();
int count = concurrentRequests.incrementAndGet();
if (count > maxConcurrentRequests) {
- context.failNow("Concurrent requests exceeded");
+ context.failNow(
+ "Concurrent requests exceeded: " + count + " > " + maxConcurrentRequests);
return;
}
+ httpRequest.endHandler(v -> concurrentRequests.decrementAndGet());
httpRequest.bodyHandler(
body -> {
context.verify(
@@ -332,15 +358,12 @@ private Future setupServer(
&& (thisRequestNum == 1 || allRequestsUseTimeout)) {
vertx.setTimer(
serverProcessingTime.toMillis(),
- timer -> {
- concurrentRequests.decrementAndGet();
- httpRequest
- .response()
- .setStatusCode(statusCode)
- .end(Buffer.buffer(dnsResponseCopy.toWire()));
- });
+ timer ->
+ httpRequest
+ .response()
+ .setStatusCode(statusCode)
+ .end(Buffer.buffer(dnsResponseCopy.toWire())));
} else {
- concurrentRequests.decrementAndGet();
httpRequest
.response()
.setStatusCode(statusCode)
diff --git a/src/test/java/org/xbill/DNS/MessageTest.java b/src/test/java/org/xbill/DNS/MessageTest.java
index 54e32d5c..48639a97 100644
--- a/src/test/java/org/xbill/DNS/MessageTest.java
+++ b/src/test/java/org/xbill/DNS/MessageTest.java
@@ -35,6 +35,7 @@
//
package org.xbill.DNS;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -177,7 +178,9 @@ void normalize() throws WireParseException {
response.addRecord(queryRecord, Section.QUESTION);
response.addRecord(queryRecord, Section.ADDITIONAL);
response = response.normalize(query, true);
- assertTrue(response.getSection(Section.ANSWER).isEmpty());
- assertTrue(response.getSection(Section.ADDITIONAL).isEmpty());
+ assertThat(response.getSection(Section.ANSWER)).isEmpty();
+ assertThat(response.getHeader().getCount(Section.ANSWER)).isZero();
+ assertThat(response.getSection(Section.ADDITIONAL)).isEmpty();
+ assertThat(response.getHeader().getCount(Section.ADDITIONAL)).isZero();
}
}
diff --git a/src/test/java/org/xbill/DNS/NioTcpClientTest.java b/src/test/java/org/xbill/DNS/NioTcpClientTest.java
index 5ef34095..a985714b 100644
--- a/src/test/java/org/xbill/DNS/NioTcpClientTest.java
+++ b/src/test/java/org/xbill/DNS/NioTcpClientTest.java
@@ -1,25 +1,36 @@
// SPDX-License-Identifier: BSD-3-Clause
package org.xbill.DNS;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.fail;
+import static org.xbill.DNS.NioClient.SELECTOR_TIMEOUT_PROPERTY;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
+import java.lang.reflect.Field;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
+import java.nio.channels.ClosedChannelException;
+import java.nio.channels.Pipe;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
import java.time.Duration;
import java.util.ArrayList;
+import java.util.ConcurrentModificationException;
import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
@@ -29,8 +40,6 @@
@Slf4j
class NioTcpClientTest {
- private static final String SELECTOR_TIMEOUT_PROPERTY = "dnsjava.nio.selector_timeout";
-
@Test
void testCloseWithoutStart() {
assertDoesNotThrow(NioClient::close);
@@ -47,6 +56,136 @@ void testSelectorTimeoutLimits(int timeout) {
}
}
+ /**
+ * Verifies that NioClient.processReadyKeys() does not throw a ConcurrentModificationException
+ * when channels are registered with the selector while processReadyKeys() is iterating over
+ * selector.selectedKeys(). This simulates concurrent modifications that can occur under high
+ * load.
+ *
+ * The test starts the selector thread, registers an initial key, and then rapidly registers
+ * new channels from another thread while making them ready. It fails if a
+ * ConcurrentModificationException is observed during the process.
+ *
+ *
Since this involves concurrency timing, the test may be flaky and not fail consistently,
+ * even if the underlying issue exists.
+ */
+ @Test
+ void testProcessReadyKeysShouldNotThrowConcurrentModificationException() throws Exception {
+ // Speed up selector loop
+ System.setProperty(SELECTOR_TIMEOUT_PROPERTY, "10");
+
+ try {
+ // Start selector thread
+ Selector selector = NioClient.selector();
+
+ // Add initial key to ensure selectedKeys isn't empty
+ Pipe initialPipe = Pipe.open();
+ initialPipe.source().configureBlocking(false);
+ SelectionKey key = initialPipe.source().register(selector, SelectionKey.OP_READ);
+ key.attach(
+ (NioClient.KeyProcessor)
+ readyKey -> {
+ try {
+ // Slow down processing
+ Thread.sleep(10);
+ } catch (InterruptedException ignored) {
+ // ignore
+ }
+ });
+
+ // Watch for unexpected exceptions
+ AtomicReference sawCME = new AtomicReference<>(null);
+ Field selectorThreadField = NioClient.class.getDeclaredField("selectorThread");
+ selectorThreadField.setAccessible(true);
+ ((Thread) selectorThreadField.get(null))
+ .setUncaughtExceptionHandler(
+ (t, e) -> {
+ if (e instanceof ConcurrentModificationException) {
+ // Violation of expected behavior
+ sawCME.set(e);
+ }
+ });
+
+ List allPipes = new ArrayList<>();
+ Queue registrationQueue = new ConcurrentLinkedQueue<>();
+ NioClient.setRegistrationsTask(
+ sel -> {
+ Pipe pipe = registrationQueue.poll();
+ if (pipe != null) {
+ try {
+ SelectionKey sk = pipe.source().register(sel, SelectionKey.OP_READ);
+ sk.attach(
+ (NioClient.KeyProcessor)
+ readyKey -> {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException ignored) {
+ // ignore
+ }
+ });
+ } catch (ClosedChannelException e) {
+ throw new RuntimeException(e);
+ }
+ }
+ },
+ true);
+
+ // Thread that registers new channels and makes some ready
+ Thread t1 =
+ new Thread(
+ () -> {
+ try {
+ for (int i = 0; i < 100; i++) {
+ Pipe pipe = Pipe.open();
+ pipe.source().configureBlocking(false);
+ registrationQueue.add(pipe);
+ allPipes.add(pipe);
+
+ // Make the channel ready to trigger selectedKeys modification
+ if (i % 2 == 0) {
+ pipe.sink().write(ByteBuffer.wrap("x".getBytes()));
+ }
+
+ // Help trigger overlap
+ Thread.sleep(2);
+ }
+ } catch (Exception ignored) {
+ // ignore
+ }
+ });
+
+ // Thread that makes the remaining registrations ready
+ Thread t2 =
+ new Thread(
+ () -> {
+ try {
+ for (int i = 0; i < 100; i++) {
+ // Make the channel ready to trigger selectedKeys modification
+ if (i % 2 != 0) {
+ allPipes.get(i).sink().write(ByteBuffer.wrap("x".getBytes()));
+ }
+
+ // Help trigger overlap
+ Thread.sleep(2);
+ }
+ } catch (Exception ignored) {
+ // ignore
+ }
+ });
+
+ t1.start();
+ t2.start();
+ t1.join(5000);
+ t2.join(5000);
+ NioClient.close();
+
+ // Assume correctness: fail if any exception was observed
+ assertThat(sawCME.get()).isNull();
+ } finally {
+ System.clearProperty(SELECTOR_TIMEOUT_PROPERTY);
+ }
+ }
+
@Test
void testResponseStream() throws InterruptedException, IOException {
try {
@@ -59,8 +198,7 @@ void testResponseStream() throws InterruptedException, IOException {
for (int i = 0; i < q.length; i++) {
q[i] = Message.newQuery(qr);
// This is not actually valid data, but it increases the payload sufficiently to fill the
- // send buffer,
- // forcing NioTcpClient.Transaction#send into the retry
+ // send buffer, forcing NioTcpClient.Transaction#send into the retry
// see https://github.com/dnsjava/dnsjava/issues/357
for (int j = 0; j < 2048; j++) {
q[i].addRecord(
diff --git a/src/test/java/org/xbill/DNS/ResolverTest.java b/src/test/java/org/xbill/DNS/ResolverTest.java
new file mode 100644
index 00000000..365254a8
--- /dev/null
+++ b/src/test/java/org/xbill/DNS/ResolverTest.java
@@ -0,0 +1,47 @@
+// SPDX-License-Identifier: BSD-3-Clause
+package org.xbill.DNS;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.fail;
+
+import java.net.UnknownHostException;
+import java.time.Duration;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import org.junit.jupiter.api.Test;
+
+class ResolverTest {
+ @Test
+ @SuppressWarnings("deprecation")
+ void resolverListenerExceptionUnwrap() throws InterruptedException, UnknownHostException {
+ // 1. Point to a blackhole address from RFC 5737 TEST-NET-1 to ensure a timeout
+ SimpleResolver resolver = new SimpleResolver("192.0.2.1");
+ resolver.setTimeout(Duration.ofSeconds(2));
+
+ Message query =
+ Message.newQuery(
+ Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN));
+ CountDownLatch latch = new CountDownLatch(1);
+
+ // 2. Use the async method with a listener
+ resolver.sendAsync(
+ query,
+ new ResolverListener() {
+ @Override
+ public void receiveMessage(Object id, Message m) {
+ fail("Received message (should not happen)");
+ latch.countDown();
+ }
+
+ @Override
+ public void handleException(Object id, Exception ex) {
+ // 3. Observe the exception type
+ assertThat(ex).isNotInstanceOf(CompletionException.class);
+ latch.countDown();
+ }
+ });
+
+ latch.await(5, TimeUnit.SECONDS);
+ }
+}
diff --git a/src/test/java/org/xbill/DNS/TSIGTest.java b/src/test/java/org/xbill/DNS/TSIGTest.java
index a83fb6e0..416e8747 100644
--- a/src/test/java/org/xbill/DNS/TSIGTest.java
+++ b/src/test/java/org/xbill/DNS/TSIGTest.java
@@ -2,6 +2,7 @@
package org.xbill.DNS;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
@@ -520,6 +521,19 @@ TCPClient createTcpClient(Duration timeout) throws IOException {
assertEquals(202, handler.getRecords().size());
}
+ @Test
+ void invalidAdditionalCount() {
+ Message q = Message.newQuery(Record.newRecord(Name.root, Type.A, DClass.IN));
+ Message m = new Message();
+ m.addRecord(Record.newRecord(Name.root, Type.A, DClass.IN), Section.QUESTION);
+ m.addRecord(Record.newRecord(Name.root, Type.A, DClass.IN), Section.ANSWER);
+ m.addRecord(
+ Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN),
+ Section.ADDITIONAL);
+ assertDoesNotThrow(m::getTSIG);
+ assertDoesNotThrow(() -> m.normalize(q).getTSIG());
+ }
+
@Getter
private static class ZoneBuilderAxfrHandler implements ZoneTransferIn.ZoneTransferHandler {
private final List records = new ArrayList<>();
diff --git a/src/test/java/org/xbill/DNS/dnssec/TestBase.java b/src/test/java/org/xbill/DNS/dnssec/TestBase.java
index 60f53b2c..a5745c78 100644
--- a/src/test/java/org/xbill/DNS/dnssec/TestBase.java
+++ b/src/test/java/org/xbill/DNS/dnssec/TestBase.java
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: BSD-3-Clause
package org.xbill.DNS.dnssec;
+import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.mock;
@@ -126,7 +127,18 @@ private void starting(TestInfo description) {
Message m;
while ((m = messageReader.readMessage(r)) != null) {
+ for (int i = 0; i < 4; i++) {
+ assertThat(m.getHeader().getCount(i))
+ .withFailMessage("Before normalization")
+ .isEqualTo(m.getSection(i).size());
+ }
+
m = m.normalize(Message.newQuery(m.getQuestion()), true);
+ for (int i = 0; i < 4; i++) {
+ assertThat(m.getHeader().getCount(i))
+ .withFailMessage("After normalization")
+ .isEqualTo(m.getSection(i).size());
+ }
queryResponsePairs.put(key(m), m);
}
@@ -286,7 +298,7 @@ protected String getEdeText(Message m) {
.flatMap(
opt ->
opt.getOptions(Code.EDNS_EXTENDED_ERROR).stream()
- .filter(o -> o instanceof ExtendedErrorCodeOption)
+ .filter(ExtendedErrorCodeOption.class::isInstance)
.findFirst()
.map(o -> ((ExtendedErrorCodeOption) o).getText()))
.orElse(null);
diff --git a/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java b/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java
index 921e0beb..20700a9b 100644
--- a/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java
+++ b/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java
@@ -5,6 +5,8 @@
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
import java.io.BufferedWriter;
import java.io.IOException;
@@ -18,6 +20,9 @@
import java.nio.file.StandardCopyOption;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileTime;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
import java.util.Optional;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
@@ -72,7 +77,7 @@ void testMissingFileIsEmptyResult() throws IOException {
}
@Test
- void testCacheLookup() throws IOException {
+ void testCacheLookupAfterFileDeleteWithoutChangeChecking() throws IOException {
Path tempHosts = Files.copy(hostsFileWindows, tempDir, StandardCopyOption.REPLACE_EXISTING);
HostsFileParser hostsFileParser = new HostsFileParser(tempHosts, false);
assertEquals(0, hostsFileParser.cacheSize());
@@ -98,6 +103,10 @@ void testFileDeletionClearsCache() throws IOException {
tempDir.resolve("testFileWatcherClearsCache"),
StandardCopyOption.REPLACE_EXISTING);
HostsFileParser hostsFileParser = new HostsFileParser(tempHosts);
+ Clock clock = mock(Clock.class);
+ hostsFileParser.setClock(clock);
+ Instant now = Clock.systemUTC().instant();
+ when(clock.instant()).thenReturn(now);
assertEquals(0, hostsFileParser.cacheSize());
assertEquals(
kubernetesAddress,
@@ -106,6 +115,7 @@ void testFileDeletionClearsCache() throws IOException {
.orElseThrow(() -> new IllegalStateException("Host entry not found")));
assertTrue(hostsFileParser.cacheSize() > 1, "Cache must not be empty");
Files.delete(tempHosts);
+ when(clock.instant()).thenReturn(now.plus(Duration.ofMinutes(6)));
assertEquals(Optional.empty(), hostsFileParser.getAddressForHost(kubernetesName, Type.A));
assertEquals(0, hostsFileParser.cacheSize());
}
@@ -119,6 +129,10 @@ void testFileChangeClearsCache() throws IOException {
StandardCopyOption.REPLACE_EXISTING);
Files.setLastModifiedTime(tempHosts, FileTime.fromMillis(0));
HostsFileParser hostsFileParser = new HostsFileParser(tempHosts);
+ Clock clock = mock(Clock.class);
+ hostsFileParser.setClock(clock);
+ Instant now = Clock.systemUTC().instant();
+ when(clock.instant()).thenReturn(now);
assertEquals(0, hostsFileParser.cacheSize());
assertEquals(
kubernetesAddress,
@@ -134,6 +148,7 @@ void testFileChangeClearsCache() throws IOException {
}
Files.setLastModifiedTime(tempHosts, FileTime.fromMillis(10_0000));
+ when(clock.instant()).thenReturn(now.plus(Duration.ofMinutes(6)));
assertEquals(
InetAddress.getByAddress(testName.toString(), localhostBytes),
hostsFileParser