diff --git a/pom.xml b/pom.xml index 9674ec39..bf65a05e 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ dnsjava dnsjava bundle - 3.6.3 + 3.6.4-SNAPSHOT dnsjava dnsjava is an implementation of DNS in Java. It supports all defined record types (including the DNSSEC types), and unknown types. It can be used for queries, zone transfers, and dynamic updates. It includes a cache @@ -30,7 +30,7 @@ scm:git:https://github.com/dnsjava/dnsjava scm:git:https://github.com/dnsjava/dnsjava https://github.com/dnsjava/dnsjava - v3.6.3 + HEAD @@ -49,14 +49,14 @@ true false - 5.11.4 + 5.13.1 4.11.0 1.7.36 - 1.18.36 - 5.16.0 - 1.79 - 4.5.11 + 1.18.38 + 5.17.0 + 1.81 + 4.5.16 1.7 2.30.0 @@ -66,6 +66,52 @@ + + + + org.apache.maven.plugins + maven-deploy-plugin + 3.1.4 + + + + org.apache.maven.plugins + maven-dependency-plugin + 3.8.1 + + + + org.apache.maven.plugins + maven-release-plugin + 3.1.1 + + + + org.apache.maven.plugins + maven-resources-plugin + 3.3.1 + + + + org.apache.maven.plugins + maven-install-plugin + 3.1.4 + + + + org.codehaus.mojo + build-helper-maven-plugin + 3.6.1 + + + + org.apache.maven.plugins + maven-site-plugin + 3.21.0 + + + + org.codehaus.mojo @@ -111,7 +157,7 @@ org.apache.maven.plugins maven-compiler-plugin - 3.13.0 + 3.14.0 -Xlint:all,-serial,-processing @@ -237,7 +283,7 @@ org.apache.maven.plugins maven-surefire-plugin - 3.5.2 + 3.5.3 3 @@ -254,7 +300,7 @@ org.jacoco jacoco-maven-plugin - 0.8.12 + 0.8.13 prepare-agent @@ -292,7 +338,7 @@ com.github.siom79.japicmp japicmp-maven-plugin - 0.23.0 + 0.23.1 @@ -325,6 +371,12 @@ true PATCH + + SUPERCLASS_ADDED + true + true + PATCH + ANNOTATION_DEPRECATED_ADDED PATCH @@ -422,16 +474,10 @@ - - org.apache.maven.plugins - maven-deploy-plugin - 3.1.3 - - org.apache.maven.plugins maven-clean-plugin - 3.4.0 + 3.5.0 @@ -445,12 +491,6 @@ - - org.apache.maven.plugins - maven-site-plugin - 3.21.0 - - org.sonatype.plugins nexus-staging-maven-plugin @@ -463,59 +503,6 @@ - - org.codehaus.mojo - animal-sniffer-maven-plugin - 1.24 - - - net.sf.androidscents.signature - android-api-level-26 - 8.0.0_r2 - - - javax.naming.NamingException - javax.naming.directory.* - sun.net.spi.nameservice.* - java.net.spi.* - - - - - org.ow2.asm - asm - 9.7.1 - - - - - animal-sniffer - test - - check - - - - - - - org.apache.maven.plugins - maven-resources-plugin - 3.3.1 - - - - org.apache.maven.plugins - maven-install-plugin - 3.1.3 - - - - org.codehaus.mojo - build-helper-maven-plugin - 3.6.0 - - org.apache.maven.plugins maven-enforcer-plugin @@ -627,7 +614,7 @@ org.assertj assertj-core - 3.27.0 + 3.27.3 test @@ -645,7 +632,7 @@ net.bytebuddy byte-buddy-agent - 1.15.11 + 1.17.6 test @@ -676,7 +663,7 @@ commons-io commons-io - 2.18.0 + 2.19.0 test @@ -703,6 +690,60 @@ ${target.jdk} + + + org.codehaus.mojo + animal-sniffer-maven-plugin + 1.24 + + + com.toasttab.android + gummy-bears-api-26 + 0.12.0 + + + javax.naming.NamingException + javax.naming.directory.* + sun.net.spi.nameservice.* + java.net.spi.* + + + + + org.ow2.asm + asm + 9.8 + + + + + animal-sniffer + test + + check + + + + + + + org.jacoco + jacoco-maven-plugin + + + report + verify + + report + + + + META-INF/** + + + + + @@ -714,7 +755,7 @@ - 5.14.2 + 5.18.0 + false + ${project.build.outputDirectory}/META-INF/versions/11 - ${project.build.outputDirectory}/META-INF/versions/11 + ${project.build.outputDirectory} - - - - - - java11-not-idea - - false - [11,) - - !idea.version - - - - - org.codehaus.mojo - build-helper-maven-plugin - - - add-java11-source - generate-sources - - add-source - - - - src/main/java11 - - - - + org.jacoco + jacoco-maven-plugin + + + org/xbill/DNS/AsyncSemaphore* + org/xbill/DNS/DohResolver* + + - - java11-idea + + java11-not-idea false - [11,18) + [11,) - idea.version + !idea.version @@ -927,14 +955,40 @@ @{argLine} --add-opens java.base/sun.net.dns=ALL-UNNAMED -javaagent:${net.bytebuddy:byte-buddy-agent:jar} + -javaagent:${org.mockito:mockito-core:jar} + + + false + ${project.build.outputDirectory}/META-INF/versions/18 ${project.build.outputDirectory}/META-INF/versions/11 - ${project.build.outputDirectory}/META-INF/versions/18 + ${project.build.outputDirectory} + + + + + + java18-not-idea + + false + [18,) + + !idea.version + + + + org.codehaus.mojo build-helper-maven-plugin diff --git a/sonar-project.properties b/sonar-project.properties index 3fcdfad1..cd897637 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -3,6 +3,7 @@ sonar.organization=dnsjava sonar.host.url=https://sonarcloud.io sonar.java.source=8 sonar.coverage.jacoco.xmlReportPaths=target/site/jacoco/jacoco.xml +sonar.scanner.skipJreProvisioning=true sonar.issue.ignore.multicriteria=S106,S107,S120,S1948,S2160 diff --git a/src/main/java/org/xbill/DNS/AsyncSemaphore.java b/src/main/java/org/xbill/DNS/AsyncSemaphore.java index 456b8ad3..0d704913 100644 --- a/src/main/java/org/xbill/DNS/AsyncSemaphore.java +++ b/src/main/java/org/xbill/DNS/AsyncSemaphore.java @@ -6,6 +6,7 @@ import java.util.Queue; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; import lombok.extern.slf4j.Slf4j; @@ -13,34 +14,50 @@ final class AsyncSemaphore { private final Queue> queue = new ArrayDeque<>(); private final Permit singletonPermit = new Permit(); + private final String name; private volatile int permits; final class Permit { - public void release() { + public void release(int id, Executor executor) { synchronized (queue) { CompletableFuture next = queue.poll(); if (next == null) { permits++; + log.trace("{} permit released id={}, available={}", name, id, permits); } else { - next.complete(this); + log.trace("{} permit released id={}, available={}, immediate next", name, id, permits); + executor.execute(() -> next.complete(this)); } } } } - AsyncSemaphore(int permits) { + AsyncSemaphore(int permits, String name) { this.permits = permits; + this.name = name; + log.debug("Using Java 8 implementation for {}", name); } - CompletionStage acquire(Duration timeout) { + CompletionStage acquire(Duration timeout, int id, Executor executor) { synchronized (queue) { if (permits > 0) { permits--; + log.trace("{} permit acquired id={}, available={}", name, id, permits); return CompletableFuture.completedFuture(singletonPermit); } else { TimeoutCompletableFuture f = new TimeoutCompletableFuture<>(); f.compatTimeout(timeout.toNanos(), TimeUnit.NANOSECONDS) - .whenComplete((result, ex) -> queue.remove(f)); + .whenCompleteAsync( + (result, ex) -> { + synchronized (queue) { + if (ex != null) { + log.trace("{} permit timed out id={}, available={}", name, id, permits); + } + queue.remove(f); + } + }, + executor); + log.trace("{} permit queued id={}, available={}", name, id, permits); queue.add(f); return f; } diff --git a/src/main/java/org/xbill/DNS/DohResolver.java b/src/main/java/org/xbill/DNS/DohResolver.java index b6707bbd..fa150cc0 100644 --- a/src/main/java/org/xbill/DNS/DohResolver.java +++ b/src/main/java/org/xbill/DNS/DohResolver.java @@ -5,8 +5,6 @@ import java.io.EOFException; import java.io.IOException; import java.io.InputStream; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; import java.net.HttpURLConnection; import java.net.SocketTimeoutException; import java.net.URI; @@ -14,25 +12,19 @@ import java.security.NoSuchAlgorithmException; import java.time.Duration; import java.time.temporal.ChronoUnit; -import java.util.Collections; import java.util.List; -import java.util.Map; -import java.util.WeakHashMap; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.Executor; -import java.util.concurrent.ForkJoinPool; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.TimeoutException; import java.util.function.Function; import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLSocketFactory; -import lombok.SneakyThrows; import lombok.Value; import lombok.extern.slf4j.Slf4j; -import org.xbill.DNS.AsyncSemaphore.Permit; -import org.xbill.DNS.utils.base64; /** * Proof-of-concept DNS over HTTP (DoH) @@ -51,141 +43,16 @@ * @since 3.0 */ @Slf4j -public final class DohResolver implements Resolver { - private static final boolean USE_HTTP_CLIENT; - private static final Map httpClients = - Collections.synchronizedMap(new WeakHashMap<>()); +public final class DohResolver extends DohResolverCommon { private final SSLSocketFactory sslSocketFactory; - private static Object defaultHttpRequestBuilder; - private static Method publisherOfByteArrayMethod; - private static Method requestBuilderTimeoutMethod; - private static Method requestBuilderCopyMethod; - private static Method requestBuilderUriMethod; - private static Method requestBuilderBuildMethod; - private static Method requestBuilderPostMethod; - - private static Method httpClientNewBuilderMethod; - private static Method httpClientBuilderTimeoutMethod; - private static Method httpClientBuilderExecutorMethod; - private static Method httpClientBuilderBuildMethod; - private static Method httpClientSendAsyncMethod; - - private static Method byteArrayBodyPublisherMethod; - private static Method httpResponseBodyMethod; - private static Method httpResponseStatusCodeMethod; - - private boolean usePost = false; - private Duration timeout = Duration.ofSeconds(5); - private String uriTemplate; - private final Duration idleConnectionTimeout; - private OPTRecord queryOPT = new OPTRecord(0, 0, 0); - private TSIG tsig; - private Executor defaultExecutor = ForkJoinPool.commonPool(); - - /** - * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections. - * - *

rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2. - */ - private final AsyncSemaphore maxConcurrentRequests; - - private final AtomicLong lastRequest = new AtomicLong(0); - private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1); - - private static final String APPLICATION_DNS_MESSAGE = "application/dns-message"; - - static { - boolean initSuccess = false; - if (!System.getProperty("java.version").startsWith("1.")) { - try { - Class httpClientBuilderClass = Class.forName("java.net.http.HttpClient$Builder"); - Class httpClientClass = Class.forName("java.net.http.HttpClient"); - Class httpVersionEnum = Class.forName("java.net.http.HttpClient$Version"); - Class httpRequestBuilderClass = Class.forName("java.net.http.HttpRequest$Builder"); - Class httpRequestClass = Class.forName("java.net.http.HttpRequest"); - Class bodyPublishersClass = Class.forName("java.net.http.HttpRequest$BodyPublishers"); - Class bodyPublisherClass = Class.forName("java.net.http.HttpRequest$BodyPublisher"); - Class httpResponseClass = Class.forName("java.net.http.HttpResponse"); - Class bodyHandlersClass = Class.forName("java.net.http.HttpResponse$BodyHandlers"); - Class bodyHandlerClass = Class.forName("java.net.http.HttpResponse$BodyHandler"); - - // HttpClient.Builder - httpClientBuilderTimeoutMethod = - httpClientBuilderClass.getDeclaredMethod("connectTimeout", Duration.class); - httpClientBuilderExecutorMethod = - httpClientBuilderClass.getDeclaredMethod("executor", Executor.class); - httpClientBuilderBuildMethod = httpClientBuilderClass.getDeclaredMethod("build"); - - // HttpClient - httpClientNewBuilderMethod = httpClientClass.getDeclaredMethod("newBuilder"); - httpClientSendAsyncMethod = - httpClientClass.getDeclaredMethod("sendAsync", httpRequestClass, bodyHandlerClass); - - // HttpRequestBuilder - Method requestBuilderHeaderMethod = - httpRequestBuilderClass.getDeclaredMethod("header", String.class, String.class); - Method requestBuilderVersionMethod = - httpRequestBuilderClass.getDeclaredMethod("version", httpVersionEnum); - requestBuilderTimeoutMethod = - httpRequestBuilderClass.getDeclaredMethod("timeout", Duration.class); - requestBuilderUriMethod = httpRequestBuilderClass.getDeclaredMethod("uri", URI.class); - requestBuilderCopyMethod = httpRequestBuilderClass.getDeclaredMethod("copy"); - requestBuilderBuildMethod = httpRequestBuilderClass.getDeclaredMethod("build"); - requestBuilderPostMethod = - httpRequestBuilderClass.getDeclaredMethod("POST", bodyPublisherClass); - - // HttpRequest - Method requestBuilderNewBuilderMethod = httpRequestClass.getDeclaredMethod("newBuilder"); - - // BodyPublishers - publisherOfByteArrayMethod = - bodyPublishersClass.getDeclaredMethod("ofByteArray", byte[].class); - - // BodyPublisher - byteArrayBodyPublisherMethod = bodyHandlersClass.getDeclaredMethod("ofByteArray"); - - // HttpResponse - httpResponseBodyMethod = httpResponseClass.getDeclaredMethod("body"); - httpResponseStatusCodeMethod = httpResponseClass.getDeclaredMethod("statusCode"); - - // defaultHttpRequestBuilder = HttpRequest.newBuilder(); - // defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2); - // defaultHttpRequestBuilder.header("Content-Type", "application/dns-message"); - // defaultHttpRequestBuilder.header("Accept", "application/dns-message"); - defaultHttpRequestBuilder = requestBuilderNewBuilderMethod.invoke(null); - @SuppressWarnings({"unchecked", "rawtypes"}) - Enum http2Version = Enum.valueOf((Class) httpVersionEnum, "HTTP_2"); - requestBuilderVersionMethod.invoke(defaultHttpRequestBuilder, http2Version); - requestBuilderHeaderMethod.invoke( - defaultHttpRequestBuilder, "Content-Type", APPLICATION_DNS_MESSAGE); - requestBuilderHeaderMethod.invoke( - defaultHttpRequestBuilder, "Accept", APPLICATION_DNS_MESSAGE); - initSuccess = true; - } catch (ClassNotFoundException - | NoSuchMethodException - | IllegalAccessException - | InvocationTargetException e) { - // fallback to Java 8 - log.warn("Java >= 11 detected, but HttpRequest not available"); - } - } - - USE_HTTP_CLIENT = initSuccess; - } - - // package-visible for testing - long getNanoTime() { - return System.nanoTime(); - } - /** * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). * * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ public DohResolver(String uriTemplate) { - this(uriTemplate, 100, Duration.ofMinutes(2)); + this(uriTemplate, 100, Duration.ZERO); } /** @@ -201,22 +68,9 @@ public DohResolver(String uriTemplate) { */ public DohResolver( String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) { - this.uriTemplate = uriTemplate; - this.idleConnectionTimeout = idleConnectionTimeout; - if (maxConcurrentRequests <= 0) { - throw new IllegalArgumentException("maxConcurrentRequests must be > 0"); - } - if (!USE_HTTP_CLIENT) { - try { - int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5")); - if (maxConcurrentRequests > javaMaxConn) { - maxConcurrentRequests = javaMaxConn; - } - } catch (NumberFormatException nfe) { - // well, use what we got - } - } - this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests); + super(uriTemplate, maxConcurrentRequests); + + log.debug("Using Java 8 implementation"); try { sslSocketFactory = SSLContext.getDefault().getSocketFactory(); } catch (NoSuchAlgorithmException e) { @@ -224,45 +78,6 @@ public DohResolver( } } - @SneakyThrows - private Object getHttpClient(Executor executor) { - return httpClients.computeIfAbsent( - executor, - key -> { - try { - // return HttpClient.newBuilder() - // .connectTimeout(timeout). - // .executor(executor) - // .build(); - Object httpClientBuilder = httpClientNewBuilderMethod.invoke(null); - httpClientBuilderTimeoutMethod.invoke(httpClientBuilder, timeout); - httpClientBuilderExecutorMethod.invoke(httpClientBuilder, key); - return httpClientBuilderBuildMethod.invoke(httpClientBuilder); - } catch (IllegalAccessException | InvocationTargetException e) { - log.warn("Could not create a HttpClient with for Executor {}", key, e); - return null; - } - }); - } - - /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */ - @Override - public void setPort(int port) { - // Not implemented, port is part of the URI - } - - /** Not implemented. */ - @Override - public void setTCP(boolean flag) { - // Not implemented, HTTP is always TCP - } - - /** Not implemented. */ - @Override - public void setIgnoreTruncation(boolean flag) { - // Not implemented, protocol uses TCP and doesn't have truncation - } - /** * Sets the EDNS information on outgoing messages. * @@ -272,87 +87,85 @@ public void setIgnoreTruncation(boolean flag) { * @param options EDNS options to be set in the OPT record */ @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility public void setEDNS(int version, int payloadSize, int flags, List options) { - switch (version) { - case -1: - queryOPT = null; - break; - - case 0: - queryOPT = new OPTRecord(0, 0, version, flags, options); - break; - - default: - throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable"); - } - } - - @Override - public void setTSIGKey(TSIG key) { - this.tsig = key; - } - - @Override - public void setTimeout(Duration timeout) { - this.timeout = timeout; - httpClients.clear(); - } - - @Override - public Duration getTimeout() { - return timeout; + // required for source- and binary compatibility + super.setEDNS(version, payloadSize, flags, options); } @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility public CompletionStage sendAsync(Message query) { - return sendAsync(query, defaultExecutor); + // required for source- and binary compatibility + return this.sendAsync(query, defaultExecutor); } @Override public CompletionStage sendAsync(Message query, Executor executor) { - if (USE_HTTP_CLIENT) { - return sendAsync11(query, executor); - } - - return sendAsync8(query, executor); - } - - private CompletionStage sendAsync8(final Message query, Executor executor) { byte[] queryBytes = prepareQuery(query).toWire(); String url = getUrl(queryBytes); long startTime = getNanoTime(); - return maxConcurrentRequests - .acquire(timeout) - .handleAsync( - (permit, ex) -> { - if (ex != null) { - return this.timeoutFailedFuture(query, ex); - } else { - try { - SendAndGetMessageBytesResponse result = - sendAndGetMessageBytes(url, queryBytes, startTime); - Message response; - if (result.rc == Rcode.NOERROR) { - response = new Message(result.responseBytes); - verifyTSIG(query, response, result.responseBytes, tsig); + int queryId = query.getHeader().getID(); + + CompletableFuture f = + maxConcurrentRequests + .acquire(timeout, queryId, executor) + .handleAsync( + (permit, ex) -> { + if (ex != null) { + return this.timeoutFailedFuture( + query, "could not acquire lock to send request", ex); } else { - response = new Message(0); - response.getHeader().setRcode(result.rc); + try { + SendAndGetMessageBytesResponse result = + sendAndGetMessageBytes(url, queryBytes, startTime); + Message response; + if (result.rc == Rcode.NOERROR) { + response = new Message(result.responseBytes); + verifyTSIG(query, response, result.responseBytes, tsig); + } else { + response = new Message(0); + response.getHeader().setRcode(result.rc); + } + + response.setResolver(this); + return CompletableFuture.completedFuture(response); + } catch (SocketTimeoutException e) { + return this.timeoutFailedFuture(query, e); + } catch (IOException | URISyntaxException e) { + return this.failedFuture(e); + } finally { + permit.release(queryId, executor); + } } + }, + executor) + .thenCompose(Function.identity()) + .toCompletableFuture(); - response.setResolver(this); - return CompletableFuture.completedFuture(response); - } catch (SocketTimeoutException e) { - return this.timeoutFailedFuture(query, e); - } catch (IOException | URISyntaxException e) { - return this.failedFuture(e); - } finally { - permit.release(); - } + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + return TimeoutCompletableFuture.compatTimeout( + f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS) + .exceptionally( + ex -> { + if (ex instanceof TimeoutException) { + throw new CompletionException( + new TimeoutException( + "Query " + + queryId + + " for " + + query.getQuestion().getName() + + "/" + + Type.string(query.getQuestion().getType()) + + " timed out in remaining " + + remainingTimeout.toMillis() + + "ms")); + } else if (ex instanceof CompletionException) { + throw (CompletionException) ex; } - }, - executor) - .thenCompose(Function.identity()); + + throw new CompletionException(ex); + }); } @Value @@ -367,15 +180,28 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( if (conn instanceof HttpsURLConnection) { ((HttpsURLConnection) conn).setSSLSocketFactory(sslSocketFactory); } - - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - conn.setConnectTimeout((int) remainingTimeout.toMillis()); - conn.setReadTimeout((int) remainingTimeout.toMillis()); conn.setRequestMethod(usePost ? "POST" : "GET"); conn.setRequestProperty("Content-Type", APPLICATION_DNS_MESSAGE); conn.setRequestProperty("Accept", APPLICATION_DNS_MESSAGE); + + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + throw new SocketTimeoutException("No time left to connect"); + } + + conn.setConnectTimeout((int) remainingTimeout.toMillis()); if (usePost) { conn.setDoOutput(true); + } + + conn.connect(); + remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + throw new SocketTimeoutException("No time left to request data"); + } + + conn.setReadTimeout((int) remainingTimeout.toMillis()); + if (usePost) { conn.getOutputStream().write(queryBytes); } @@ -395,13 +221,23 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( while ((r = is.read(responseBytes, offset, responseBytes.length - offset)) > 0) { offset += r; remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - throw new SocketTimeoutException(); + + // Don't throw if we just received all data + if (offset != responseBytes.length + && (remainingTimeout.isNegative() || remainingTimeout.isZero())) { + throw new SocketTimeoutException( + "Timed out waiting for response data, got " + + offset + + " of " + + responseBytes.length + + " expected bytes"); } } + if (offset < responseBytes.length) { throw new EOFException("Could not read expected content length"); } + return new SendAndGetMessageBytesResponse(Rcode.NOERROR, responseBytes); } else { try (ByteArrayOutputStream bos = new ByteArrayOutputStream()) { @@ -409,8 +245,9 @@ private SendAndGetMessageBytesResponse sendAndGetMessageBytes( int r; while ((r = is.read(buffer, 0, buffer.length)) > 0) { remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - throw new SocketTimeoutException(); + if (remainingTimeout.isNegative() || remainingTimeout.isZero()) { + throw new SocketTimeoutException( + "Timed out waiting for response data, got " + bos.size() + " bytes so far"); } bos.write(buffer, 0, r); } @@ -436,275 +273,10 @@ private void discardStream(InputStream es) throws IOException { } } - private CompletionStage sendAsync11(final Message query, Executor executor) { - long startTime = getNanoTime(); - byte[] queryBytes = prepareQuery(query).toWire(); - String url = getUrl(queryBytes); - - // var requestBuilder = defaultHttpRequestBuilder.copy(); - // requestBuilder.uri(URI.create(url)); - Object requestBuilder; - try { - requestBuilder = requestBuilderCopyMethod.invoke(defaultHttpRequestBuilder); - requestBuilderUriMethod.invoke(requestBuilder, URI.create(url)); - if (usePost) { - // requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes)); - requestBuilderPostMethod.invoke( - requestBuilder, publisherOfByteArrayMethod.invoke(null, queryBytes)); - } - } catch (IllegalAccessException | InvocationTargetException e) { - return failedFuture(e); - } - - // check if this request needs to be done synchronously because of HttpClient's stupidity to - // not use the connection pool for HTTP/2 until one connection is successfully established, - // which could lead to hundreds of connections (and threads with the default executor) - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - return initialRequestLock - .acquire(remainingTimeout) - .handle( - (initialRequestPermit, initialRequestEx) -> { - if (initialRequestEx != null) { - return this.timeoutFailedFuture(query, initialRequestEx); - } else { - return sendAsync11WithInitialRequestPermit( - query, executor, startTime, requestBuilder, initialRequestPermit); - } - }) - .thenCompose(Function.identity()); - } - - private CompletionStage sendAsync11WithInitialRequestPermit( - Message query, - Executor executor, - long startTime, - Object requestBuilder, - Permit initialRequestPermit) { - long lastRequestTime = lastRequest.get(); - boolean isInitialRequest = idleConnectionTimeout.toNanos() > getNanoTime() - lastRequestTime; - if (!isInitialRequest) { - initialRequestPermit.release(); - } - - // check if we already exceeded the query timeout while checking the initial connection - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - if (isInitialRequest) { - initialRequestPermit.release(); - } - return timeoutFailedFuture(query, null); - } - - // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the - // request, but fail with an IOException which also CLOSES the connection... *facepalm* - return maxConcurrentRequests - .acquire(remainingTimeout) - .handle( - (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> { - if (maxConcurrentRequestEx != null) { - if (isInitialRequest) { - initialRequestPermit.release(); - } - return this.timeoutFailedFuture(query, maxConcurrentRequestEx); - } else { - return sendAsync11WithConcurrentRequestPermit( - query, - executor, - startTime, - requestBuilder, - initialRequestPermit, - isInitialRequest, - maxConcurrentRequestPermit); - } - }) - .thenCompose(Function.identity()); - } - - private CompletionStage sendAsync11WithConcurrentRequestPermit( - Message query, - Executor executor, - long startTime, - Object requestBuilder, - Permit initialRequestPermit, - boolean isInitialRequest, - Permit maxConcurrentRequestPermit) { - // check if the stream lock acquisition took too long - Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); - if (remainingTimeout.isNegative()) { - if (isInitialRequest) { - initialRequestPermit.release(); - } - maxConcurrentRequestPermit.release(); - return timeoutFailedFuture(query, null); - } - - // var httpRequest = requestBuilder.timeout(remainingTimeout).build(); - // var bodyHandler = HttpResponse.BodyHandlers.ofByteArray(); - // return getHttpClient(executor).sendAsync(httpRequest, bodyHandler) - try { - Object httpClient = getHttpClient(executor); - requestBuilderTimeoutMethod.invoke(requestBuilder, remainingTimeout); - Object httpRequest = requestBuilderBuildMethod.invoke(requestBuilder); - Object bodyHandler = byteArrayBodyPublisherMethod.invoke(null); - CompletableFuture f = - ((CompletableFuture) - httpClientSendAsyncMethod.invoke(httpClient, httpRequest, bodyHandler)) - .whenComplete( - (result, ex) -> { - if (ex == null) { - lastRequest.set(startTime); - } - maxConcurrentRequestPermit.release(); - if (isInitialRequest) { - initialRequestPermit.release(); - } - }) - .handleAsync( - (response, ex) -> { - if (ex != null) { - if (ex.getCause().getClass().getSimpleName().equals("HttpTimeoutException")) { - return this.timeoutFailedFuture(query, ex.getCause()); - } else { - return this.failedFuture(ex); - } - } else { - try { - Message responseMessage; - // int rc = response.statusCode(); - int rc = (int) httpResponseStatusCodeMethod.invoke(response); - if (rc >= 200 && rc < 300) { - // byte[] responseBytes = response.body(); - byte[] responseBytes = (byte[]) httpResponseBodyMethod.invoke(response); - responseMessage = new Message(responseBytes); - verifyTSIG(query, responseMessage, responseBytes, tsig); - } else { - responseMessage = new Message(); - responseMessage.getHeader().setRcode(Rcode.SERVFAIL); - } - - responseMessage.setResolver(this); - return CompletableFuture.completedFuture(responseMessage); - } catch (IOException | IllegalAccessException | InvocationTargetException e) { - return this.failedFuture(e); - } - } - }, - executor) - .thenCompose(Function.identity()); - return TimeoutCompletableFuture.compatTimeout( - f, remainingTimeout.toMillis(), TimeUnit.MILLISECONDS); - } catch (IllegalAccessException | InvocationTargetException e) { - return failedFuture(e); - } - } - - private CompletableFuture failedFuture(Throwable e) { + @Override + protected CompletableFuture failedFuture(Throwable e) { CompletableFuture f = new CompletableFuture<>(); f.completeExceptionally(e); return f; } - - private CompletableFuture timeoutFailedFuture(Message query, Throwable inner) { - return failedFuture( - new IOException( - "Query " - + query.getHeader().getID() - + " for " - + query.getQuestion().getName() - + "/" - + Type.string(query.getQuestion().getType()) - + " timed out", - inner)); - } - - private String getUrl(byte[] queryBytes) { - String url = uriTemplate; - if (!usePost) { - url += "?dns=" + base64.toString(queryBytes, true); - } - return url; - } - - private Message prepareQuery(Message query) { - Message preparedQuery = query.clone(); - preparedQuery.getHeader().setID(0); - if (queryOPT != null && preparedQuery.getOPT() == null) { - preparedQuery.addRecord(queryOPT, Section.ADDITIONAL); - } - - if (tsig != null) { - tsig.apply(preparedQuery, null); - } - - return preparedQuery; - } - - private void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) { - if (tsig == null) { - return; - } - - int error = tsig.verify(response, b, query.getGeneratedTSIG()); - log.debug( - "TSIG verify for query {}, {}/{}: {}", - query.getHeader().getID(), - query.getQuestion().getName(), - Type.string(query.getQuestion().getType()), - Rcode.TSIGstring(error)); - } - - /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */ - public boolean isUsePost() { - return usePost; - } - - /** - * Sets the HTTP method to use for resolving. - * - * @param usePost {@code true} to use POST, {@code false} to use GET (the default). - */ - public void setUsePost(boolean usePost) { - this.usePost = usePost; - } - - /** Gets the current URI used for resolving. */ - public String getUriTemplate() { - return uriTemplate; - } - - /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ - public void setUriTemplate(String uriTemplate) { - this.uriTemplate = uriTemplate; - } - - /** - * Gets the default {@link Executor} for request handling, defaults to {@link - * ForkJoinPool#commonPool()}. - * - * @since 3.3 - * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used. - */ - @Deprecated - public Executor getExecutor() { - return defaultExecutor; - } - - /** - * Sets the default {@link Executor} for request handling. - * - * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link - * ForkJoinPool#commonPool()}). - * @since 3.3 - * @deprecated Use {@link #sendAsync(Message, Executor)}. - */ - @Deprecated - public void setExecutor(Executor executor) { - this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor; - httpClients.clear(); - } - - @Override - public String toString() { - return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}"; - } } diff --git a/src/main/java/org/xbill/DNS/DohResolverCommon.java b/src/main/java/org/xbill/DNS/DohResolverCommon.java new file mode 100644 index 00000000..0fb8ac07 --- /dev/null +++ b/src/main/java/org/xbill/DNS/DohResolverCommon.java @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; +import java.util.concurrent.ForkJoinPool; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicLong; +import lombok.extern.slf4j.Slf4j; +import org.xbill.DNS.utils.base64; + +@Slf4j +abstract class DohResolverCommon implements Resolver { + /** + * Maximum concurrent HTTP/2 streams or HTTP/1.1 connections. + * + *

rfc7540#section-6.5.2 recommends a minimum of 100 streams for HTTP/2. + */ + protected final AsyncSemaphore maxConcurrentRequests; + + protected final AtomicLong lastRequest = new AtomicLong(0); + + protected static final String APPLICATION_DNS_MESSAGE = "application/dns-message"; + + protected boolean usePost = false; + protected Duration timeout = Duration.ofSeconds(5); + protected String uriTemplate; + protected OPTRecord queryOPT = new OPTRecord(0, 0, 0); + protected TSIG tsig; + protected Executor defaultExecutor = ForkJoinPool.commonPool(); + + // package-visible for testing + long getNanoTime() { + return System.nanoTime(); + } + + /** + * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). + * + * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} + * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1 + * connections for Java 8. On Java 8 this cannot exceed the system property {@code + * http.maxConnections}. + */ + protected DohResolverCommon(String uriTemplate, int maxConcurrentRequests) { + this.uriTemplate = uriTemplate; + if (maxConcurrentRequests <= 0) { + throw new IllegalArgumentException("maxConcurrentRequests must be > 0"); + } + + try { + int javaMaxConn = Integer.parseInt(System.getProperty("http.maxConnections", "5")); + if (maxConcurrentRequests > javaMaxConn) { + maxConcurrentRequests = javaMaxConn; + } + } catch (NumberFormatException nfe) { + // well, use what we got + } + + this.maxConcurrentRequests = new AsyncSemaphore(maxConcurrentRequests, "concurrent"); + } + + /** Not implemented. Specify the port in {@link #setUriTemplate(String)} if required. */ + @Override + public void setPort(int port) { + // Not implemented, port is part of the URI + } + + /** Not implemented. */ + @Override + public void setTCP(boolean flag) { + // Not implemented, HTTP is always TCP + } + + /** Not implemented. */ + @Override + public void setIgnoreTruncation(boolean flag) { + // Not implemented, protocol uses TCP and doesn't have truncation + } + + /** + * Sets the EDNS information on outgoing messages. + * + * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS. + * @param payloadSize ignored + * @param flags EDNS extended flags to be set in the OPT record. + * @param options EDNS options to be set in the OPT record + */ + @Override + public void setEDNS(int version, int payloadSize, int flags, List options) { + switch (version) { + case -1: + queryOPT = null; + break; + + case 0: + queryOPT = new OPTRecord(0, 0, version, flags, options); + break; + + default: + throw new IllegalArgumentException("invalid EDNS version - must be 0 or -1 to disable"); + } + } + + @Override + public void setTSIGKey(TSIG key) { + this.tsig = key; + } + + @Override + public void setTimeout(Duration timeout) { + this.timeout = timeout; + } + + @Override + public Duration getTimeout() { + return timeout; + } + + protected String getUrl(byte[] queryBytes) { + String url = uriTemplate; + if (!usePost) { + url += "?dns=" + base64.toString(queryBytes, true); + } + return url; + } + + protected Message prepareQuery(Message query) { + Message preparedQuery = query.clone(); + preparedQuery.getHeader().setID(0); + if (queryOPT != null && preparedQuery.getOPT() == null) { + preparedQuery.addRecord(queryOPT, Section.ADDITIONAL); + } + + if (tsig != null) { + tsig.apply(preparedQuery, null); + } + + return preparedQuery; + } + + protected void verifyTSIG(Message query, Message response, byte[] b, TSIG tsig) { + if (tsig == null) { + return; + } + + int error = tsig.verify(response, b, query.getGeneratedTSIG()); + log.debug( + "TSIG verify for query {}, {}/{}: {}", + query.getHeader().getID(), + query.getQuestion().getName(), + Type.string(query.getQuestion().getType()), + Rcode.TSIGstring(error)); + } + + /** Returns {@code true} if the HTTP method POST to resolve, {@code false} if GET is used. */ + public boolean isUsePost() { + return usePost; + } + + /** + * Sets the HTTP method to use for resolving. + * + * @param usePost {@code true} to use POST, {@code false} to use GET (the default). + */ + public void setUsePost(boolean usePost) { + this.usePost = usePost; + } + + /** Gets the current URI used for resolving. */ + public String getUriTemplate() { + return uriTemplate; + } + + /** Sets the URI to use for resolving, e.g. {@code https://dns.google/dns-query} */ + public void setUriTemplate(String uriTemplate) { + this.uriTemplate = uriTemplate; + } + + /** + * Gets the default {@link Executor} for request handling, defaults to {@link + * ForkJoinPool#commonPool()}. + * + * @since 3.3 + * @deprecated not applicable if {@link #sendAsync(Message, Executor)} is used. + */ + @Deprecated + public Executor getExecutor() { + return defaultExecutor; + } + + /** + * Sets the default {@link Executor} for request handling. + * + * @param executor The new {@link Executor}, can be {@code null} (which is equivalent to {@link + * ForkJoinPool#commonPool()}). + * @since 3.3 + * @deprecated Use {@link #sendAsync(Message, Executor)}. + */ + @Deprecated + public void setExecutor(Executor executor) { + this.defaultExecutor = executor == null ? ForkJoinPool.commonPool() : executor; + } + + @Override + public String toString() { + return "DohResolver {" + (usePost ? "POST " : "GET ") + uriTemplate + "}"; + } + + protected abstract CompletableFuture failedFuture(Throwable e); + + protected final CompletableFuture timeoutFailedFuture(Message query, Throwable inner) { + return timeoutFailedFuture(query, null, inner); + } + + protected final CompletableFuture timeoutFailedFuture( + Message query, String message, Throwable inner) { + return failedFuture( + new TimeoutException( + "Query " + + query.getHeader().getID() + + " for " + + query.getQuestion().getName() + + "/" + + Type.string(query.getQuestion().getType()) + + " timed out" + + (message != null ? ": " + message : "") + + (inner != null && inner.getMessage() != null ? ", " + inner.getMessage() : ""))); + } +} diff --git a/src/main/java/org/xbill/DNS/Lookup.java b/src/main/java/org/xbill/DNS/Lookup.java index 135be61e..875d03b0 100644 --- a/src/main/java/org/xbill/DNS/Lookup.java +++ b/src/main/java/org/xbill/DNS/Lookup.java @@ -8,6 +8,7 @@ import java.net.InetAddress; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -43,7 +44,7 @@ public final class Lookup { private static Resolver defaultResolver; - private static List defaultSearchPath; + private static List defaultSearchPath = Collections.emptyList(); private static Map defaultCaches; private static int defaultNdots; private static HostsFileParser defaultHostsFileParser; @@ -173,6 +174,11 @@ public static synchronized List getDefaultSearchPath() { * made absolute. */ public static synchronized void setDefaultSearchPath(List domains) { + if (domains == null) { + defaultSearchPath = Collections.emptyList(); + return; + } + defaultSearchPath = convertSearchPathDomainList(domains); } @@ -184,6 +190,11 @@ public static synchronized void setDefaultSearchPath(List domains) { * made absolute. */ public static synchronized void setDefaultSearchPath(Name... domains) { + if (domains == null) { + defaultSearchPath = Collections.emptyList(); + return; + } + setDefaultSearchPath(Arrays.asList(domains)); } @@ -196,16 +207,16 @@ public static synchronized void setDefaultSearchPath(Name... domains) { public static synchronized void setDefaultSearchPath(String... domains) throws TextParseException { if (domains == null) { - defaultSearchPath = null; + defaultSearchPath = Collections.emptyList(); return; } - List newdomains = new ArrayList<>(domains.length); + List newDomains = new ArrayList<>(domains.length); for (String domain : domains) { - newdomains.add(Name.fromString(domain, Name.root)); + newDomains.add(Name.fromString(domain, Name.root)); } - defaultSearchPath = newdomains; + defaultSearchPath = newDomains; } /** @@ -395,6 +406,11 @@ public void setResolver(Resolver resolver) { * made absolute. */ public void setSearchPath(List domains) { + if (domains == null) { + this.searchPath = Collections.emptyList(); + return; + } + this.searchPath = convertSearchPathDomainList(domains); } @@ -406,6 +422,11 @@ public void setSearchPath(List domains) { * made absolute. */ public void setSearchPath(Name... domains) { + if (domains == null) { + this.searchPath = Collections.emptyList(); + return; + } + setSearchPath(Arrays.asList(domains)); } @@ -417,15 +438,16 @@ public void setSearchPath(Name... domains) { */ public void setSearchPath(String... domains) throws TextParseException { if (domains == null) { - this.searchPath = null; + this.searchPath = Collections.emptyList(); return; } - List newdomains = new ArrayList<>(domains.length); + List newDomains = new ArrayList<>(domains.length); for (String domain : domains) { - newdomains.add(Name.fromString(domain, Name.root)); + newDomains.add(Name.fromString(domain, Name.root)); } - this.searchPath = newdomains; + + this.searchPath = newDomains; } /** @@ -671,8 +693,6 @@ public Record[] run() { } if (name.isAbsolute()) { resolve(name, null); - } else if (searchPath == null) { - resolve(name, Name.root); } else { if (name.labels() > ndots) { resolve(name, Name.root); diff --git a/src/main/java/org/xbill/DNS/Message.java b/src/main/java/org/xbill/DNS/Message.java index bb214c8d..767450b8 100644 --- a/src/main/java/org/xbill/DNS/Message.java +++ b/src/main/java/org/xbill/DNS/Message.java @@ -810,9 +810,11 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) List additionalSectionSets = getSectionRRsets(Section.ADDITIONAL); List authoritySectionSets = getSectionRRsets(Section.AUTHORITY); - List cleanedAnswerSection = new ArrayList<>(); - List cleanedAuthoritySection = new ArrayList<>(); - List cleanedAdditionalSection = new ArrayList<>(); + @SuppressWarnings("unchecked") + List[] cleanedSection = new ArrayList[4]; + cleanedSection[Section.ANSWER] = new ArrayList<>(); + cleanedSection[Section.AUTHORITY] = new ArrayList<>(); + cleanedSection[Section.ADDITIONAL] = new ArrayList<>(); boolean hadNsInAuthority = false; // For the ANSWER section, remove all "irrelevant" records and add synthesized CNAMEs from @@ -843,7 +845,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) // If DNAME was queried, don't attempt to synthesize CNAME if (query.getQuestion().getType() != Type.DNAME) { // The DNAME is valid, accept it - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); // Check if the next rrset is correct CNAME, otherwise synthesize a CNAME RRset nextRRSet = answerSectionSets.size() >= i + 2 ? answerSectionSets.get(i + 1) : null; @@ -863,7 +865,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) // Add a synthesized CNAME; TTL=0 to avoid caching Name dnameTarget = sname.fromDNAME(dname); - cleanedAnswerSection.add( + cleanedSection[Section.ANSWER].add( new RRset(new CNAMERecord(sname, dname.getDClass(), 0, dnameTarget))); sname = dnameTarget; @@ -872,7 +874,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) for (i++; i < answerSectionSets.size(); i++) { rrset = answerSectionSets.get(i); if (rrset.getName().equals(oldSname)) { - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); } else { break; } @@ -943,14 +945,14 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) } sname = ((CNAMERecord) rrset.first()).getTarget(); - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); // In CNAME ANY response, can have data after CNAME if (query.getQuestion().getType() == Type.ANY) { for (i++; i < answerSectionSets.size(); i++) { rrset = answerSectionSets.get(i); if (rrset.getName().equals(oldSname)) { - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); } else { break; } @@ -973,9 +975,9 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) } // Mark the additional names from relevant RRset as OK - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); if (sname.equals(rrset.getName())) { - addAdditionalRRset(rrset, additionalSectionSets, cleanedAdditionalSection); + addAdditionalRRset(rrset, additionalSectionSets, cleanedSection[Section.ADDITIONAL]); } } @@ -1045,15 +1047,25 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) } } - cleanedAuthoritySection.add(rrset); - addAdditionalRRset(rrset, additionalSectionSets, cleanedAdditionalSection); + cleanedSection[Section.AUTHORITY].add(rrset); + addAdditionalRRset(rrset, additionalSectionSets, cleanedSection[Section.ADDITIONAL]); } Message cleanedMessage = new Message(this.getHeader()); cleanedMessage.sections[Section.QUESTION] = this.sections[Section.QUESTION]; - cleanedMessage.sections[Section.ANSWER] = rrsetListToRecords(cleanedAnswerSection); - cleanedMessage.sections[Section.AUTHORITY] = rrsetListToRecords(cleanedAuthoritySection); - cleanedMessage.sections[Section.ADDITIONAL] = rrsetListToRecords(cleanedAdditionalSection); + for (int section : new int[] {Section.ANSWER, Section.AUTHORITY, Section.ADDITIONAL}) { + cleanedMessage.sections[section] = rrsetListToRecords(cleanedSection[section]); + + // Fixup counts in the header + cleanedMessage + .getHeader() + .setCount( + section, + cleanedMessage.sections[section] == null + ? 0 + : cleanedMessage.sections[section].size()); + } + return cleanedMessage; } diff --git a/src/main/java/org/xbill/DNS/NioClient.java b/src/main/java/org/xbill/DNS/NioClient.java index d9718e00..14211a29 100644 --- a/src/main/java/org/xbill/DNS/NioClient.java +++ b/src/main/java/org/xbill/DNS/NioClient.java @@ -10,6 +10,7 @@ import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.util.Iterator; +import java.util.function.Consumer; import lombok.AccessLevel; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -24,9 +25,9 @@ *

The following configuration parameter is available: * *

- *
dnsjava.nio.selector_timeout + *
{@value SELECTOR_TIMEOUT_PROPERTY} *
Set selector timeout in milliseconds. Default/Max 1000, Min 1. - *
dnsjava.nio.register_shutdown_hook + *
{@value REGISTER_SHUTDOWN_HOOK_PROPERTY} *
Register Shutdown Hook termination of NIO. Default True. *
* @@ -35,16 +36,24 @@ @Slf4j @NoArgsConstructor(access = AccessLevel.NONE) public abstract class NioClient { + static final String SELECTOR_TIMEOUT_PROPERTY = "dnsjava.nio.selector_timeout"; + static final String REGISTER_SHUTDOWN_HOOK_PROPERTY = "dnsjava.nio.register_shutdown_hook"; + private static final Object NIO_CLIENT_LOCK = new Object(); + /** Packet logger, if available. */ private static PacketLogger packetLogger = null; private static final Runnable[] TIMEOUT_TASKS = new Runnable[2]; - private static final Runnable[] REGISTRATIONS_TASKS = new Runnable[2]; + + @SuppressWarnings("unchecked") + private static final Consumer[] REGISTRATIONS_TASKS = new Consumer[2]; + private static final Runnable[] CLOSE_TASKS = new Runnable[2]; private static Thread selectorThread; private static Thread closeThread; private static volatile Selector selector; private static volatile boolean run; + private static volatile boolean closeDone; interface KeyProcessor { void processReadyKey(SelectionKey key); @@ -52,7 +61,7 @@ interface KeyProcessor { static Selector selector() throws IOException { if (selector == null) { - synchronized (NioClient.class) { + synchronized (NIO_CLIENT_LOCK) { if (selector == null) { selector = Selector.open(); log.debug("Starting dnsjava NIO selector thread"); @@ -63,8 +72,7 @@ static Selector selector() throws IOException { selectorThread.start(); closeThread = new Thread(() -> close(true)); closeThread.setName("dnsjava NIO shutdown hook"); - if (Boolean.parseBoolean( - System.getProperty("dnsjava.nio.register_shutdown_hook", "true"))) { + if (Boolean.parseBoolean(System.getProperty(REGISTER_SHUTDOWN_HOOK_PROPERTY, "true"))) { Runtime.getRuntime().addShutdownHook(closeThread); } } @@ -74,13 +82,23 @@ static Selector selector() throws IOException { return selector; } - /** Shutdown the network I/O used by the {@link SimpleResolver}. */ + /** + * Shutdown the network I/O used by the {@link SimpleResolver}. + * + * @implNote Does not wait until the selector thread has stopped. But users may immediately start + * using the {@link NioClient} again. + * @since 3.4.0 + */ public static void close() { close(false); } private static void close(boolean fromHook) { run = false; + Selector localSelector = selector; + if (localSelector != null) { + selector.wakeup(); + } if (!fromHook) { try { @@ -90,40 +108,26 @@ private static void close(boolean fromHook) { } } - try { - runTasks(CLOSE_TASKS); - } catch (Exception e) { - log.warn("Failed to execute shutdown task, ignoring and continuing close", e); - } - - Selector localSelector = selector; - Thread localSelectorThread = selectorThread; - synchronized (NioClient.class) { - selector = null; - selectorThread = null; - closeThread = null; + if (localSelector == null) { + // Prevent hanging when close() was called without starting + return; } - if (localSelector != null) { - localSelector.wakeup(); + synchronized (NIO_CLIENT_LOCK) { try { - localSelector.close(); - } catch (IOException e) { - log.warn("Failed to properly close selector, ignoring and continuing close", e); - } - } - - if (localSelectorThread != null) { - try { - localSelectorThread.join(); + while (!closeDone) { + NIO_CLIENT_LOCK.wait(); + } } catch (InterruptedException e) { Thread.currentThread().interrupt(); + } finally { + closeDone = false; } } } static void runSelector() { - int timeout = Integer.getInteger("dnsjava.nio.selector_timeout", 1000); + int timeout = Integer.getInteger(SELECTOR_TIMEOUT_PROPERTY, 1000); if (timeout <= 0 || timeout > 1000) { throw new IllegalArgumentException("Invalid selector_timeout, must be between 1 and 1000"); @@ -136,7 +140,7 @@ static void runSelector() { } if (run) { - runTasks(REGISTRATIONS_TASKS); + runRegistrationTasks(); processReadyKeys(); } } catch (IOException e) { @@ -145,30 +149,74 @@ static void runSelector() { // ignore } } + + runClose(); log.debug("dnsjava NIO selector thread stopped"); } - static synchronized void setTimeoutTask(Runnable r, boolean isTcpClient) { + private static void runClose() { + try { + runTasks(CLOSE_TASKS); + } catch (Exception e) { + log.warn("Failed to execute shutdown task, ignoring and continuing close", e); + } + + Selector localSelector = selector; + Thread localSelectorThread = selectorThread; + synchronized (NIO_CLIENT_LOCK) { + selector = null; + selectorThread = null; + closeThread = null; + closeDone = true; + NIO_CLIENT_LOCK.notifyAll(); + } + + if (localSelector != null) { + try { + localSelector.close(); + } catch (IOException e) { + log.warn("Failed to properly close selector, ignoring and continuing close", e); + } + } + + if (localSelectorThread != null) { + try { + localSelectorThread.join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } + + static void setTimeoutTask(Runnable r, boolean isTcpClient) { addTask(TIMEOUT_TASKS, r, isTcpClient); } - static synchronized void setRegistrationsTask(Runnable r, boolean isTcpClient) { - addTask(REGISTRATIONS_TASKS, r, isTcpClient); + static void setRegistrationsTask(Consumer r, boolean isTcpClient) { + addRegistrationTask(r, isTcpClient); } - static synchronized void setCloseTask(Runnable r, boolean isTcpClient) { + static void setCloseTask(Runnable r, boolean isTcpClient) { addTask(CLOSE_TASKS, r, isTcpClient); } - private static void addTask(Runnable[] closeTasks, Runnable r, boolean isTcpClient) { + private static void addTask(Runnable[] tasks, Runnable r, boolean isTcpClient) { if (isTcpClient) { - closeTasks[0] = r; + tasks[0] = r; } else { - closeTasks[1] = r; + tasks[1] = r; } } - private static synchronized void runTasks(Runnable[] runnables) { + private static void addRegistrationTask(Consumer r, boolean isTcpClient) { + if (isTcpClient) { + REGISTRATIONS_TASKS[0] = r; + } else { + REGISTRATIONS_TASKS[1] = r; + } + } + + private static void runTasks(Runnable[] runnables) { Runnable r0 = runnables[0]; if (r0 != null) { r0.run(); @@ -179,6 +227,17 @@ private static synchronized void runTasks(Runnable[] runnables) { } } + private static void runRegistrationTasks() { + Consumer r0 = REGISTRATIONS_TASKS[0]; + if (r0 != null) { + r0.accept(selector); + } + Consumer r1 = REGISTRATIONS_TASKS[1]; + if (r1 != null) { + r1.accept(selector); + } + } + private static void processReadyKeys() { Iterator it = selector.selectedKeys().iterator(); while (it.hasNext()) { diff --git a/src/main/java/org/xbill/DNS/NioTcpClient.java b/src/main/java/org/xbill/DNS/NioTcpClient.java index 566c00eb..200a5331 100644 --- a/src/main/java/org/xbill/DNS/NioTcpClient.java +++ b/src/main/java/org/xbill/DNS/NioTcpClient.java @@ -32,7 +32,7 @@ final class NioTcpClient extends NioClient implements TcpIoClient { setCloseTask(this::closeTcp, true); } - private void processPendingRegistrations() { + private void processPendingRegistrations(Selector selector) { while (!registrationQueue.isEmpty()) { ChannelState state = registrationQueue.poll(); if (state == null) { @@ -40,7 +40,6 @@ private void processPendingRegistrations() { } try { - final Selector selector = selector(); if (!state.channel.isConnected()) { state.channel.register(selector, SelectionKey.OP_CONNECT, state); } else { diff --git a/src/main/java/org/xbill/DNS/NioUdpClient.java b/src/main/java/org/xbill/DNS/NioUdpClient.java index ce6607f2..e2ccfef2 100644 --- a/src/main/java/org/xbill/DNS/NioUdpClient.java +++ b/src/main/java/org/xbill/DNS/NioUdpClient.java @@ -54,7 +54,7 @@ final class NioUdpClient extends NioClient implements UdpIoClient { setCloseTask(this::closeUdp, false); } - private void processPendingRegistrations() { + private void processPendingRegistrations(Selector selector) { while (!registrationQueue.isEmpty()) { Transaction t = registrationQueue.poll(); if (t == null) { @@ -63,7 +63,7 @@ private void processPendingRegistrations() { try { log.trace("Registering OP_READ for transaction with id {}", t.id); - t.channel.register(selector(), SelectionKey.OP_READ, t); + t.channel.register(selector, SelectionKey.OP_READ, t); t.send(); } catch (IOException e) { t.completeExceptionally(e); diff --git a/src/main/java/org/xbill/DNS/Resolver.java b/src/main/java/org/xbill/DNS/Resolver.java index 5d3c0ff3..5b6ce511 100644 --- a/src/main/java/org/xbill/DNS/Resolver.java +++ b/src/main/java/org/xbill/DNS/Resolver.java @@ -9,6 +9,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -239,11 +240,16 @@ default Object sendAsync(Message query, ResolverListener listener) { (result, throwable) -> { if (throwable != null) { Exception exception; + if (throwable instanceof CompletionException && throwable.getCause() != null) { + throwable = throwable.getCause(); + } + if (throwable instanceof Exception) { exception = (Exception) throwable; } else { exception = new Exception(throwable); } + listener.handleException(id, exception); return null; } diff --git a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java index 24796df6..e2ae60f6 100644 --- a/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java +++ b/src/main/java/org/xbill/DNS/TimeoutCompletableFuture.java @@ -1,8 +1,6 @@ // SPDX-License-Identifier: BSD-3-Clause package org.xbill.DNS; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -10,57 +8,28 @@ import java.util.concurrent.TimeoutException; import lombok.extern.slf4j.Slf4j; -/** - * Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. On Java 9+ - * the built-in method is called. - */ +/** Utility class to backport {@code orTimeout} to Java 8 with a custom implementation. */ @Slf4j class TimeoutCompletableFuture extends CompletableFuture { - private static final Method orTimeoutMethod; - - static { - Method localOrTimeoutMethod; - if (!System.getProperty("java.version").startsWith("1.")) { - try { - localOrTimeoutMethod = - CompletableFuture.class.getMethod("orTimeout", long.class, TimeUnit.class); - } catch (NoSuchMethodException e) { - localOrTimeoutMethod = null; - log.warn( - "CompletableFuture.orTimeout method not found in Java 9+, using custom implementation", - e); - } - } else { - localOrTimeoutMethod = null; - } - orTimeoutMethod = localOrTimeoutMethod; - } - public CompletableFuture compatTimeout(long timeout, TimeUnit unit) { return compatTimeout(this, timeout, unit); } - @SuppressWarnings("unchecked") public static CompletableFuture compatTimeout( CompletableFuture f, long timeout, TimeUnit unit) { - if (orTimeoutMethod == null) { - return orTimeout(f, timeout, unit); - } else { - try { - return (CompletableFuture) orTimeoutMethod.invoke(f, timeout, unit); - } catch (IllegalAccessException | InvocationTargetException e) { - return orTimeout(f, timeout, unit); - } + if (timeout <= 0) { + f.completeExceptionally(new TimeoutException("timeout is " + timeout + ", but must be > 0")); } - } - private static CompletableFuture orTimeout( - CompletableFuture f, long timeout, TimeUnit unit) { ScheduledFuture sf = TimeoutScheduler.executor.schedule( () -> { if (!f.isDone()) { - f.completeExceptionally(new TimeoutException()); + f.completeExceptionally( + new TimeoutException( + "Timeout of " + + unit.toMillis(timeout) + + "ms has elapsed before the task completed")); } }, timeout, diff --git a/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java b/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java index c617bc6d..cd5dfb88 100644 --- a/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java +++ b/src/main/java/org/xbill/DNS/hosts/HostsFileParser.java @@ -8,12 +8,15 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.time.Clock; +import java.time.Duration; import java.time.Instant; import java.util.Arrays; -import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.xbill.DNS.Address; @@ -31,12 +34,23 @@ public final class HostsFileParser { private final int maxFullCacheFileSizeBytes = Integer.parseInt(System.getProperty("dnsjava.hostsfile.max_size_bytes", "16384")); + private final Duration fileChangeCheckInterval = + Duration.ofMillis( + Integer.parseInt( + System.getProperty("dnsjava.hostsfile.change_check_interval_ms", "300000"))); - private final Map hostsCache = new HashMap<>(); private final Path path; private final boolean clearCacheOnChange; + private Clock clock = Clock.systemUTC(); + + @SuppressWarnings("java:S3077") + private volatile Map hostsCache; + + private Instant lastFileModificationCheckTime = Instant.MIN; private Instant lastFileReadTime = Instant.MIN; private boolean isEntireFileParsed; + private boolean hostsFileWarningLogged = false; + private long hostsFileSizeBytes; /** * Creates a new instance based on the current OS's default. Unix and alike (or rather everything @@ -86,8 +100,7 @@ public HostsFileParser(Path path, boolean clearCacheOnChange) { * @throws IllegalArgumentException when {@code type} is not {@link org.xbill.DNS.Type#A} or{@link * org.xbill.DNS.Type#AAAA}. */ - public synchronized Optional getAddressForHost(Name name, int type) - throws IOException { + public Optional getAddressForHost(Name name, int type) throws IOException { Objects.requireNonNull(name, "name is required"); if (type != Type.A && type != Type.AAAA) { throw new IllegalArgumentException("type can only be A or AAAA"); @@ -100,13 +113,11 @@ public synchronized Optional getAddressForHost(Name name, int type) return Optional.of(cachedAddress); } - if (isEntireFileParsed || !Files.exists(path)) { + if (isEntireFileParsed) { return Optional.empty(); } - if (Files.size(path) <= maxFullCacheFileSizeBytes) { - parseEntireHostsFile(); - } else { + if (hostsFileSizeBytes > maxFullCacheFileSizeBytes) { searchHostsFileForEntry(name, type); } @@ -116,9 +127,11 @@ public synchronized Optional getAddressForHost(Name name, int type) private void parseEntireHostsFile() throws IOException { String line; int lineNumber = 0; + AtomicInteger addressFailures = new AtomicInteger(0); + AtomicInteger nameFailures = new AtomicInteger(0); try (BufferedReader hostsReader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) { while ((line = hostsReader.readLine()) != null) { - LineData lineData = parseLine(++lineNumber, line); + LineData lineData = parseLine(++lineNumber, line, addressFailures, nameFailures); if (lineData != null) { for (Name lineName : lineData.names) { InetAddress lineAddress = @@ -129,15 +142,24 @@ private void parseEntireHostsFile() throws IOException { } } - isEntireFileParsed = true; + if (!hostsFileWarningLogged && (addressFailures.get() > 0 || nameFailures.get() > 0)) { + log.warn( + "Failed to parse entire hosts file {}, address failures={}, name failures={}", + path, + addressFailures.get(), + nameFailures); + hostsFileWarningLogged = true; + } } private void searchHostsFileForEntry(Name name, int type) throws IOException { String line; int lineNumber = 0; + AtomicInteger addressFailures = new AtomicInteger(0); + AtomicInteger nameFailures = new AtomicInteger(0); try (BufferedReader hostsReader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) { while ((line = hostsReader.readLine()) != null) { - LineData lineData = parseLine(++lineNumber, line); + LineData lineData = parseLine(++lineNumber, line, addressFailures, nameFailures); if (lineData != null) { for (Name lineName : lineData.names) { boolean isSearchedEntry = lineName.equals(name); @@ -151,6 +173,16 @@ private void searchHostsFileForEntry(Name name, int type) throws IOException { } } } + + if (!hostsFileWarningLogged && (addressFailures.get() > 0 || nameFailures.get() > 0)) { + log.warn( + "Failed to find {} in hosts file {}, address failures={}, name failures={}", + name, + path, + addressFailures.get(), + nameFailures); + hostsFileWarningLogged = true; + } } @RequiredArgsConstructor @@ -160,7 +192,8 @@ private static final class LineData { final Iterable names; } - private LineData parseLine(int lineNumber, String line) { + private LineData parseLine( + int lineNumber, String line, AtomicInteger addressFailures, AtomicInteger nameFailures) { String[] lineTokens = getLineTokens(line); if (lineTokens.length < 2) { return null; @@ -174,24 +207,26 @@ private LineData parseLine(int lineNumber, String line) { } if (lineAddressBytes == null) { - log.warn("Could not decode address {}, {}#L{}", lineTokens[0], path, lineNumber); + log.debug("Could not decode address {}, {}#L{}", lineTokens[0], path, lineNumber); + addressFailures.incrementAndGet(); return null; } Iterable lineNames = Arrays.stream(lineTokens) .skip(1) - .map(lineTokenName -> safeName(lineTokenName, lineNumber)) + .map(lineTokenName -> safeName(lineTokenName, lineNumber, nameFailures)) .filter(Objects::nonNull) ::iterator; return new LineData(lineAddressType, lineAddressBytes, lineNames); } - private Name safeName(String name, int lineNumber) { + private Name safeName(String name, int lineNumber, AtomicInteger nameFailures) { try { return Name.fromString(name, Name.root); } catch (TextParseException e) { - log.warn("Could not decode name {}, {}#L{}, skipping", name, path, lineNumber); + log.debug("Could not decode name {}, {}#L{}, skipping", name, path, lineNumber); + nameFailures.incrementAndGet(); return null; } } @@ -207,21 +242,61 @@ private String[] getLineTokens(String line) { } private void validateCache() throws IOException { - if (clearCacheOnChange) { + if (!clearCacheOnChange) { + if (hostsCache == null) { + synchronized (this) { + if (hostsCache == null) { + readHostsFile(); + } + } + } + + return; + } + + if (lastFileModificationCheckTime.plus(fileChangeCheckInterval).isBefore(clock.instant())) { + log.debug("Checked for changes more than 5minutes ago, checking"); // A filewatcher / inotify etc. would be nicer, but doesn't work. c.f. the write up at // https://blog.arkey.fr/2019/09/13/watchservice-and-bind-mount/ - Instant fileTime = - Files.exists(path) ? Files.getLastModifiedTime(path).toInstant() : Instant.MAX; - if (fileTime.isAfter(lastFileReadTime)) { - // skip logging noise when the cache is empty anyway - if (!hostsCache.isEmpty()) { - log.info("Local hosts database has changed at {}, clearing cache", fileTime); - hostsCache.clear(); + + synchronized (this) { + if (!lastFileModificationCheckTime + .plus(fileChangeCheckInterval) + .isBefore(clock.instant())) { + log.debug("Never mind, check fulfilled in another thread"); + return; + } + + lastFileModificationCheckTime = clock.instant(); + readHostsFile(); + } + } + } + + private void readHostsFile() throws IOException { + if (Files.exists(path)) { + Instant fileTime = Files.getLastModifiedTime(path).toInstant(); + if (!lastFileReadTime.equals(fileTime)) { + createOrClearCache(); + + hostsFileSizeBytes = Files.size(path); + if (hostsFileSizeBytes <= maxFullCacheFileSizeBytes) { + parseEntireHostsFile(); + isEntireFileParsed = true; } - isEntireFileParsed = false; lastFileReadTime = fileTime; } + } else { + createOrClearCache(); + } + } + + private void createOrClearCache() { + if (hostsCache == null) { + hostsCache = new ConcurrentHashMap<>(); + } else { + hostsCache.clear(); } } @@ -231,6 +306,10 @@ private String key(Name name, int type) { // for unit testing only int cacheSize() { - return hostsCache.size(); + return hostsCache == null ? 0 : hostsCache.size(); + } + + void setClock(Clock clock) { + this.clock = clock; } } diff --git a/src/main/java11/org/xbill/DNS/AsyncSemaphore.java b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java new file mode 100644 index 00000000..d2dcf6dd --- /dev/null +++ b/src/main/java11/org/xbill/DNS/AsyncSemaphore.java @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import java.time.Duration; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import lombok.extern.slf4j.Slf4j; + +@Slf4j +final class AsyncSemaphore { + private final Queue> queue = new ArrayDeque<>(); + private final Permit singletonPermit = new Permit(); + private final String name; + private volatile int permits; + + final class Permit { + public void release(int id, Executor executor) { + synchronized (queue) { + CompletableFuture next = queue.poll(); + if (next == null) { + permits++; + log.trace("{} permit released id={}, available={}", name, id, permits); + } else { + log.trace("{} permit released id={}, available={}, immediate next", name, id, permits); + next.completeAsync(() -> this, executor); + } + } + } + } + + AsyncSemaphore(int permits, String name) { + this.permits = permits; + this.name = name; + log.debug("Using Java 11+ implementation for {}", name); + } + + CompletionStage acquire(Duration timeout, int id, Executor executor) { + synchronized (queue) { + if (permits > 0) { + permits--; + log.trace("{} permit acquired id={}, available={}", name, id, permits); + return CompletableFuture.completedFuture(singletonPermit); + } else { + CompletableFuture f = new CompletableFuture<>(); + f.orTimeout(timeout.toNanos(), TimeUnit.NANOSECONDS) + .whenCompleteAsync( + (result, ex) -> { + synchronized (queue) { + if (ex != null) { + log.trace("{} permit timed out id={}, available={}", name, id, permits); + } + queue.remove(f); + } + }, + executor); + log.trace("{} permit queued id={}, available={}", name, id, permits); + queue.add(f); + return f; + } + } + } +} diff --git a/src/main/java11/org/xbill/DNS/DohResolver.java b/src/main/java11/org/xbill/DNS/DohResolver.java new file mode 100644 index 00000000..8e73eca5 --- /dev/null +++ b/src/main/java11/org/xbill/DNS/DohResolver.java @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.net.http.HttpTimeoutException; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.WeakHashMap; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.Executor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.function.Function; +import lombok.SneakyThrows; +import lombok.extern.slf4j.Slf4j; +import org.xbill.DNS.AsyncSemaphore.Permit; + +/** + * Proof-of-concept DNS over HTTP (DoH) + * resolver. This class is not suitable for high load scenarios because of the shortcomings of + * Java's built-in HTTP clients. For more control, implement your own {@link Resolver} using e.g. OkHttp. + * + *

On Java 8, it uses HTTP/1.1, which is against the recommendation of RFC 8484 to use HTTP/2 and + * thus slower. On Java 11 or newer, HTTP/2 is always used, but the built-in HttpClient has its own + * issues with connection handling. + * + *

As of 2020-09-13, the following limits of public resolvers for HTTP/2 were observed: + *

  • https://cloudflare-dns.com/dns-query: max streams=250, idle timeout=400s + *
  • https://dns.google/dns-query: max streams=100, idle timeout=240s + * + * @since 3.0 + */ +@Slf4j +public final class DohResolver extends DohResolverCommon { + private static final String APPLICATION_DNS_MESSAGE = "application/dns-message"; + private static final Map httpClients = + Collections.synchronizedMap(new WeakHashMap<>()); + private static final HttpRequest.Builder defaultHttpRequestBuilder; + + private final AsyncSemaphore initialRequestLock = new AsyncSemaphore(1, "initial request"); + + private final Duration idleConnectionTimeout; + + static { + defaultHttpRequestBuilder = HttpRequest.newBuilder(); + defaultHttpRequestBuilder.version(HttpClient.Version.HTTP_2); + defaultHttpRequestBuilder.header("Content-Type", APPLICATION_DNS_MESSAGE); + defaultHttpRequestBuilder.header("Accept", APPLICATION_DNS_MESSAGE); + } + + /** + * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). + * + * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} + */ + public DohResolver(String uriTemplate) { + this(uriTemplate, 100, Duration.ofMinutes(2)); + } + + /** + * Creates a new DoH resolver that performs lookups with HTTP GET and the default timeout (5s). + * + * @param uriTemplate the URI to use for resolving, e.g. {@code https://dns.google/dns-query} + * @param maxConcurrentRequests Maximum concurrent HTTP/2 streams for Java 11+ or HTTP/1.1 + * connections for Java 8. On Java 8 this cannot exceed the system property {@code + * http.maxConnections}. + * @param idleConnectionTimeout Max. idle time for HTTP/2 connections until a request is + * serialized. Applies to Java 11+ only. + * @since 3.3 + */ + public DohResolver( + String uriTemplate, int maxConcurrentRequests, Duration idleConnectionTimeout) { + super(uriTemplate, maxConcurrentRequests); + log.debug("Using Java 11+ implementation"); + this.idleConnectionTimeout = idleConnectionTimeout; + } + + @SneakyThrows + private HttpClient getHttpClient(Executor executor) { + return httpClients.computeIfAbsent( + executor, + key -> { + try { + return HttpClient.newBuilder().connectTimeout(timeout).executor(executor).build(); + } catch (IllegalArgumentException e) { + log.warn("Could not create a HttpClient for Executor {}", key, e); + return null; + } + }); + } + + @Override + public void setTimeout(Duration timeout) { + this.timeout = timeout; + httpClients.clear(); + } + + /** + * Sets the EDNS information on outgoing messages. + * + * @param version The EDNS version to use. 0 indicates EDNS0 and -1 indicates no EDNS. + * @param payloadSize ignored + * @param flags EDNS extended flags to be set in the OPT record. + * @param options EDNS options to be set in the OPT record + */ + @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility + public void setEDNS(int version, int payloadSize, int flags, List options) { + // required for source- and binary compatibility + super.setEDNS(version, payloadSize, flags, options); + } + + @Override + @SuppressWarnings("java:S1185") // required for source- and binary compatibility + public CompletionStage sendAsync(Message query) { + return this.sendAsync(query, defaultExecutor); + } + + @Override + public CompletionStage sendAsync(Message query, Executor executor) { + long startTime = getNanoTime(); + byte[] queryBytes = prepareQuery(query).toWire(); + String url = getUrl(queryBytes); + + var requestBuilder = defaultHttpRequestBuilder.copy(); + requestBuilder.uri(URI.create(url)); + if (usePost) { + requestBuilder.POST(HttpRequest.BodyPublishers.ofByteArray(queryBytes)); + } + + // check if this request needs to be done synchronously because of HttpClient's stupidity to + // not use the connection pool for HTTP/2 until one connection is successfully established, + // which could lead to hundreds of connections (and threads with the default executor) + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + return timeoutFailedFuture(query, "no time left to acquire lock for first request", null); + } + + return initialRequestLock + .acquire(remainingTimeout, query.getHeader().getID(), executor) + .handle( + (initialRequestPermit, initialRequestEx) -> { + if (initialRequestEx != null) { + return this.timeoutFailedFuture(query, initialRequestEx); + } else { + return sendAsyncWithInitialRequestPermit( + query, executor, startTime, requestBuilder, initialRequestPermit); + } + }) + .thenCompose(Function.identity()); + } + + private CompletionStage sendAsyncWithInitialRequestPermit( + Message query, + Executor executor, + long startTime, + HttpRequest.Builder requestBuilder, + Permit initialRequestPermit) { + int queryId = query.getHeader().getID(); + long lastRequestTime = lastRequest.get(); + long requestDeltaNanos = getNanoTime() - lastRequestTime; + boolean isInitialRequest = + lastRequestTime == 0 || idleConnectionTimeout.toNanos() < requestDeltaNanos; + if (!isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + + // check if we already exceeded the query timeout while checking the initial connection + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + + return timeoutFailedFuture( + query, "no time left to acquire lock for concurrent request", null); + } + + // Lock a HTTP/2 stream. Another stupidity of HttpClient to not simply queue the + // request, but fail with an IOException which also CLOSES the connection... *facepalm* + return maxConcurrentRequests + .acquire(remainingTimeout, queryId, executor) + .handle( + (maxConcurrentRequestPermit, maxConcurrentRequestEx) -> { + if (maxConcurrentRequestEx != null) { + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + return this.timeoutFailedFuture( + query, + "timed out waiting for a concurrent request lease", + maxConcurrentRequestEx); + } else { + return sendAsyncWithConcurrentRequestPermit( + query, + executor, + startTime, + requestBuilder, + initialRequestPermit, + isInitialRequest, + maxConcurrentRequestPermit); + } + }) + .thenCompose(Function.identity()); + } + + private CompletionStage sendAsyncWithConcurrentRequestPermit( + Message query, + Executor executor, + long startTime, + HttpRequest.Builder requestBuilder, + Permit initialRequestPermit, + boolean isInitialRequest, + Permit maxConcurrentRequestPermit) { + int queryId = query.getHeader().getID(); + + // check if the stream lock acquisition took too long + Duration remainingTimeout = timeout.minus(getNanoTime() - startTime, ChronoUnit.NANOS); + if (remainingTimeout.toMillis() <= 0) { + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + + maxConcurrentRequestPermit.release(queryId, executor); + return timeoutFailedFuture( + query, "no time left to acquire lock for concurrent request", null); + } + + var httpRequest = requestBuilder.timeout(remainingTimeout).build(); + var bodyHandler = HttpResponse.BodyHandlers.ofByteArray(); + return getHttpClient(executor) + .sendAsync(httpRequest, bodyHandler) + .whenComplete( + (result, ex) -> { + if (ex == null) { + lastRequest.set(startTime); + } + maxConcurrentRequestPermit.release(queryId, executor); + if (isInitialRequest) { + initialRequestPermit.release(queryId, executor); + } + }) + .handleAsync( + (response, ex) -> { + if (ex != null) { + if (ex instanceof HttpTimeoutException) { + return this.timeoutFailedFuture( + query, "http request did not complete", ex.getCause()); + } else { + return CompletableFuture.failedFuture(ex); + } + } else { + try { + Message responseMessage; + int rc = response.statusCode(); + if (rc >= 200 && rc < 300) { + byte[] responseBytes = response.body(); + responseMessage = new Message(responseBytes); + verifyTSIG(query, responseMessage, responseBytes, tsig); + } else { + responseMessage = new Message(); + responseMessage.getHeader().setRcode(Rcode.SERVFAIL); + } + + responseMessage.setResolver(this); + return CompletableFuture.completedFuture(responseMessage); + } catch (IOException e) { + return CompletableFuture.failedFuture(e); + } + } + }, + executor) + .thenCompose(Function.identity()) + .orTimeout(remainingTimeout.toMillis(), TimeUnit.MILLISECONDS) + .exceptionally( + ex -> { + if (ex instanceof TimeoutException) { + throw new CompletionException( + new TimeoutException( + "Query " + + query.getHeader().getID() + + " for " + + query.getQuestion().getName() + + "/" + + Type.string(query.getQuestion().getType()) + + " timed out in remaining " + + remainingTimeout.toMillis() + + "ms")); + } else if (ex instanceof CompletionException) { + throw (CompletionException) ex; + } + + throw new CompletionException(ex); + }); + } + + @Override + protected CompletableFuture failedFuture(Throwable e) { + return CompletableFuture.failedFuture(e); + } +} diff --git a/src/test/java/org/xbill/DNS/DohResolverTest.java b/src/test/java/org/xbill/DNS/DohResolverTest.java index 26daf19c..6d1500d6 100644 --- a/src/test/java/org/xbill/DNS/DohResolverTest.java +++ b/src/test/java/org/xbill/DNS/DohResolverTest.java @@ -28,8 +28,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledForJreRange; import org.junit.jupiter.api.condition.JRE; import org.junit.jupiter.api.extension.ExtendWith; @@ -38,8 +38,8 @@ import org.mockito.stubbing.Answer; @ExtendWith(VertxExtension.class) +@Slf4j class DohResolverTest { - private DohResolver resolver; private final Name queryName = Name.fromConstantString("example.com."); private final Record qr = Record.newRecord(queryName, Type.A, DClass.IN); private final Message qm = Message.newQuery(qr); @@ -48,7 +48,6 @@ class DohResolverTest { @BeforeEach void beforeEach() throws UnknownHostException { - resolver = new DohResolver("http://localhost"); Record ar = new ARecord( Name.fromConstantString("example.com."), @@ -59,11 +58,16 @@ void beforeEach() throws UnknownHostException { a.addRecord(ar, Section.ANSWER); } + private DohResolver getResolver() { + return new DohResolver("http://localhost"); + } + @ParameterizedTest @ValueSource(booleans = {false, true}) void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = getResolver(); resolver.setUsePost(usePost); - setupResolverWithServer(Duration.ZERO, 200, 1, vertx, context) + setupResolverWithServer(resolver, Duration.ZERO, 200, 1, vertx, context) .onSuccess( server -> Future.fromCompletionStage(resolver.sendAsync(qm)) @@ -79,10 +83,13 @@ void simpleResolve(boolean usePost, Vertx vertx, VertxTestContext context) { })))); } - @Test - void timeoutResolve(Vertx vertx, VertxTestContext context) { + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void timeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = getResolver(); resolver.setTimeout(Duration.ofSeconds(1)); - setupResolverWithServer(Duration.ofSeconds(5), 200, 1, vertx, context) + resolver.setUsePost(usePost); + setupResolverWithServer(resolver, Duration.ofSeconds(5), 200, 1, vertx, context) .onSuccess( server -> Future.fromCompletionStage(resolver.sendAsync(qm)) @@ -98,9 +105,12 @@ void timeoutResolve(Vertx vertx, VertxTestContext context) { })))); } - @Test - void servfailResolve(Vertx vertx, VertxTestContext context) { - setupResolverWithServer(Duration.ZERO, 301, 1, vertx, context) + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void servfailResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = getResolver(); + resolver.setUsePost(usePost); + setupResolverWithServer(resolver, Duration.ZERO, 301, 1, vertx, context) .onSuccess( server -> Future.fromCompletionStage(resolver.sendAsync(qm)) @@ -114,12 +124,14 @@ void servfailResolve(Vertx vertx, VertxTestContext context) { })))); } - @Test - void limitRequestsResolve(Vertx vertx, VertxTestContext context) { - resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2)); + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void limitRequestsResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = new DohResolver("http://localhost", 5, Duration.ofMinutes(2)); + resolver.setUsePost(usePost); int requests = 100; Checkpoint cpPass = context.checkpoint(requests); - setupResolverWithServer(Duration.ofMillis(100), 200, 5, vertx, context) + setupResolverWithServer(resolver, Duration.ofMillis(100), 200, 5, vertx, context) .onSuccess( server -> { for (int i = 0; i < requests; i++) { @@ -137,13 +149,15 @@ void limitRequestsResolve(Vertx vertx, VertxTestContext context) { }); } - @Test - void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) { - resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void initialRequestSlowResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + resolver.setUsePost(usePost); int requests = 20; allRequestsUseTimeout = false; Checkpoint cpPass = context.checkpoint(requests); - setupResolverWithServer(Duration.ofSeconds(1), 200, 2, vertx, context) + setupResolverWithServer(resolver, Duration.ofSeconds(1), 200, 2, vertx, context) .onSuccess( server -> { for (int i = 0; i < requests; i++) { @@ -161,19 +175,23 @@ void initialRequestSlowResolve(Vertx vertx, VertxTestContext context) { }); } - @Test - void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) { - resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + @ParameterizedTest + @ValueSource(booleans = {false, true}) + void initialRequestTimeoutResolve(boolean usePost, Vertx vertx, VertxTestContext context) { + DohResolver resolver = new DohResolver("http://localhost", 2, Duration.ofMinutes(2)); + resolver.setUsePost(usePost); resolver.setTimeout(Duration.ofSeconds(1)); int requests = 20; allRequestsUseTimeout = false; Checkpoint cpPass = context.checkpoint(requests - 1); Checkpoint cpFail = context.checkpoint(); - setupResolverWithServer(Duration.ofSeconds(2), 200, 2, vertx, context) + setupResolverWithServer(resolver, Duration.ofSeconds(2), 200, 2, vertx, context) .onSuccess( server -> { + Message q = qm.clone(); + q.getHeader().setID(0); resolver - .sendAsync(qm) + .sendAsync(q) .whenComplete( (result, ex) -> { if (ex == null) { @@ -185,9 +203,11 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) { vertx.setTimer( 1000, timer -> { - for (int i = 0; i < requests - 1; i++) { + for (int i = 1; i < requests; i++) { + Message qq = qm.clone(); + qq.getHeader().setID(i); resolver - .sendAsync(qm) + .sendAsync(qq) .whenComplete( (result, ex) -> { if (ex == null) { @@ -202,6 +222,7 @@ void initialRequestTimeoutResolve(Vertx vertx, VertxTestContext context) { } private Future setupResolverWithServer( + DohResolver resolver, Duration responseDelay, int statusCode, int maxConcurrentRequests, @@ -214,12 +235,14 @@ private Future setupResolverWithServer( @EnabledForJreRange( min = JRE.JAVA_9, disabledReason = "Java 8 implementation doesn't have the initial request guard") - @Test + @ParameterizedTest + @ValueSource(booleans = {false, true}) void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime( - Vertx vertx, VertxTestContext context) { + boolean usePost, Vertx vertx, VertxTestContext context) { AtomicLong startNanos = new AtomicLong(System.nanoTime()); - resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2))); + DohResolver resolver = spy(new DohResolver("http://localhost", 2, Duration.ofMinutes(2))); resolver.setTimeout(Duration.ofSeconds(1)); + resolver.setUsePost(usePost); // Simulate a nanoTime value that is lower than the idle timeout doAnswer((Answer) invocationOnMock -> System.nanoTime() - startNanos.get()) .when(resolver) @@ -243,11 +266,13 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime( AtomicBoolean firstCallCompleted = new AtomicBoolean(false); - setupResolverWithServer(Duration.ofMillis(100L), 200, 2, vertx, context) + setupResolverWithServer(resolver, Duration.ofMillis(100L), 200, 2, vertx, context) .onSuccess( server -> { // First call - CompletionStage firstCall = resolver.sendAsync(qm); + CompletionStage firstCall = + resolver.sendAsync(qm).whenComplete((msg, ex) -> firstCallCompleted.set(true)); + // Ensure second call was made after first call and uses a different query startNanos.addAndGet(TimeUnit.MILLISECONDS.toNanos(20)); CompletionStage secondCall = resolver.sendAsync(Message.newQuery(qr)); @@ -261,7 +286,6 @@ void initialRequestGuardIfIdleConnectionTimeIsLargerThanSystemNanoTime( assertEquals(Rcode.NOERROR, result.getHeader().getRcode()); assertEquals(0, result.getHeader().getID()); assertEquals(queryName, result.getQuestion().getName()); - firstCallCompleted.set(true); }))); Future.fromCompletionStage(secondCall) @@ -302,10 +326,12 @@ private Future setupServer( int thisRequestNum = requestCount.incrementAndGet(); int count = concurrentRequests.incrementAndGet(); if (count > maxConcurrentRequests) { - context.failNow("Concurrent requests exceeded"); + context.failNow( + "Concurrent requests exceeded: " + count + " > " + maxConcurrentRequests); return; } + httpRequest.endHandler(v -> concurrentRequests.decrementAndGet()); httpRequest.bodyHandler( body -> { context.verify( @@ -332,15 +358,12 @@ private Future setupServer( && (thisRequestNum == 1 || allRequestsUseTimeout)) { vertx.setTimer( serverProcessingTime.toMillis(), - timer -> { - concurrentRequests.decrementAndGet(); - httpRequest - .response() - .setStatusCode(statusCode) - .end(Buffer.buffer(dnsResponseCopy.toWire())); - }); + timer -> + httpRequest + .response() + .setStatusCode(statusCode) + .end(Buffer.buffer(dnsResponseCopy.toWire()))); } else { - concurrentRequests.decrementAndGet(); httpRequest .response() .setStatusCode(statusCode) diff --git a/src/test/java/org/xbill/DNS/MessageTest.java b/src/test/java/org/xbill/DNS/MessageTest.java index 54e32d5c..48639a97 100644 --- a/src/test/java/org/xbill/DNS/MessageTest.java +++ b/src/test/java/org/xbill/DNS/MessageTest.java @@ -35,6 +35,7 @@ // package org.xbill.DNS; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -177,7 +178,9 @@ void normalize() throws WireParseException { response.addRecord(queryRecord, Section.QUESTION); response.addRecord(queryRecord, Section.ADDITIONAL); response = response.normalize(query, true); - assertTrue(response.getSection(Section.ANSWER).isEmpty()); - assertTrue(response.getSection(Section.ADDITIONAL).isEmpty()); + assertThat(response.getSection(Section.ANSWER)).isEmpty(); + assertThat(response.getHeader().getCount(Section.ANSWER)).isZero(); + assertThat(response.getSection(Section.ADDITIONAL)).isEmpty(); + assertThat(response.getHeader().getCount(Section.ADDITIONAL)).isZero(); } } diff --git a/src/test/java/org/xbill/DNS/NioTcpClientTest.java b/src/test/java/org/xbill/DNS/NioTcpClientTest.java index 5ef34095..a985714b 100644 --- a/src/test/java/org/xbill/DNS/NioTcpClientTest.java +++ b/src/test/java/org/xbill/DNS/NioTcpClientTest.java @@ -1,25 +1,36 @@ // SPDX-License-Identifier: BSD-3-Clause package org.xbill.DNS; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.fail; +import static org.xbill.DNS.NioClient.SELECTOR_TIMEOUT_PROPERTY; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.lang.reflect.Field; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketTimeoutException; import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.Pipe; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; import java.time.Duration; import java.util.ArrayList; +import java.util.ConcurrentModificationException; import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import lombok.extern.slf4j.Slf4j; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; @@ -29,8 +40,6 @@ @Slf4j class NioTcpClientTest { - private static final String SELECTOR_TIMEOUT_PROPERTY = "dnsjava.nio.selector_timeout"; - @Test void testCloseWithoutStart() { assertDoesNotThrow(NioClient::close); @@ -47,6 +56,136 @@ void testSelectorTimeoutLimits(int timeout) { } } + /** + * Verifies that NioClient.processReadyKeys() does not throw a ConcurrentModificationException + * when channels are registered with the selector while processReadyKeys() is iterating over + * selector.selectedKeys(). This simulates concurrent modifications that can occur under high + * load. + * + *

    The test starts the selector thread, registers an initial key, and then rapidly registers + * new channels from another thread while making them ready. It fails if a + * ConcurrentModificationException is observed during the process. + * + *

    Since this involves concurrency timing, the test may be flaky and not fail consistently, + * even if the underlying issue exists. + */ + @Test + void testProcessReadyKeysShouldNotThrowConcurrentModificationException() throws Exception { + // Speed up selector loop + System.setProperty(SELECTOR_TIMEOUT_PROPERTY, "10"); + + try { + // Start selector thread + Selector selector = NioClient.selector(); + + // Add initial key to ensure selectedKeys isn't empty + Pipe initialPipe = Pipe.open(); + initialPipe.source().configureBlocking(false); + SelectionKey key = initialPipe.source().register(selector, SelectionKey.OP_READ); + key.attach( + (NioClient.KeyProcessor) + readyKey -> { + try { + // Slow down processing + Thread.sleep(10); + } catch (InterruptedException ignored) { + // ignore + } + }); + + // Watch for unexpected exceptions + AtomicReference sawCME = new AtomicReference<>(null); + Field selectorThreadField = NioClient.class.getDeclaredField("selectorThread"); + selectorThreadField.setAccessible(true); + ((Thread) selectorThreadField.get(null)) + .setUncaughtExceptionHandler( + (t, e) -> { + if (e instanceof ConcurrentModificationException) { + // Violation of expected behavior + sawCME.set(e); + } + }); + + List allPipes = new ArrayList<>(); + Queue registrationQueue = new ConcurrentLinkedQueue<>(); + NioClient.setRegistrationsTask( + sel -> { + Pipe pipe = registrationQueue.poll(); + if (pipe != null) { + try { + SelectionKey sk = pipe.source().register(sel, SelectionKey.OP_READ); + sk.attach( + (NioClient.KeyProcessor) + readyKey -> { + try { + Thread.sleep(10); + } catch (InterruptedException ignored) { + // ignore + } + }); + } catch (ClosedChannelException e) { + throw new RuntimeException(e); + } + } + }, + true); + + // Thread that registers new channels and makes some ready + Thread t1 = + new Thread( + () -> { + try { + for (int i = 0; i < 100; i++) { + Pipe pipe = Pipe.open(); + pipe.source().configureBlocking(false); + registrationQueue.add(pipe); + allPipes.add(pipe); + + // Make the channel ready to trigger selectedKeys modification + if (i % 2 == 0) { + pipe.sink().write(ByteBuffer.wrap("x".getBytes())); + } + + // Help trigger overlap + Thread.sleep(2); + } + } catch (Exception ignored) { + // ignore + } + }); + + // Thread that makes the remaining registrations ready + Thread t2 = + new Thread( + () -> { + try { + for (int i = 0; i < 100; i++) { + // Make the channel ready to trigger selectedKeys modification + if (i % 2 != 0) { + allPipes.get(i).sink().write(ByteBuffer.wrap("x".getBytes())); + } + + // Help trigger overlap + Thread.sleep(2); + } + } catch (Exception ignored) { + // ignore + } + }); + + t1.start(); + t2.start(); + t1.join(5000); + t2.join(5000); + NioClient.close(); + + // Assume correctness: fail if any exception was observed + assertThat(sawCME.get()).isNull(); + } finally { + System.clearProperty(SELECTOR_TIMEOUT_PROPERTY); + } + } + @Test void testResponseStream() throws InterruptedException, IOException { try { @@ -59,8 +198,7 @@ void testResponseStream() throws InterruptedException, IOException { for (int i = 0; i < q.length; i++) { q[i] = Message.newQuery(qr); // This is not actually valid data, but it increases the payload sufficiently to fill the - // send buffer, - // forcing NioTcpClient.Transaction#send into the retry + // send buffer, forcing NioTcpClient.Transaction#send into the retry // see https://github.com/dnsjava/dnsjava/issues/357 for (int j = 0; j < 2048; j++) { q[i].addRecord( diff --git a/src/test/java/org/xbill/DNS/ResolverTest.java b/src/test/java/org/xbill/DNS/ResolverTest.java new file mode 100644 index 00000000..365254a8 --- /dev/null +++ b/src/test/java/org/xbill/DNS/ResolverTest.java @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +import java.net.UnknownHostException; +import java.time.Duration; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; + +class ResolverTest { + @Test + @SuppressWarnings("deprecation") + void resolverListenerExceptionUnwrap() throws InterruptedException, UnknownHostException { + // 1. Point to a blackhole address from RFC 5737 TEST-NET-1 to ensure a timeout + SimpleResolver resolver = new SimpleResolver("192.0.2.1"); + resolver.setTimeout(Duration.ofSeconds(2)); + + Message query = + Message.newQuery( + Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN)); + CountDownLatch latch = new CountDownLatch(1); + + // 2. Use the async method with a listener + resolver.sendAsync( + query, + new ResolverListener() { + @Override + public void receiveMessage(Object id, Message m) { + fail("Received message (should not happen)"); + latch.countDown(); + } + + @Override + public void handleException(Object id, Exception ex) { + // 3. Observe the exception type + assertThat(ex).isNotInstanceOf(CompletionException.class); + latch.countDown(); + } + }); + + latch.await(5, TimeUnit.SECONDS); + } +} diff --git a/src/test/java/org/xbill/DNS/TSIGTest.java b/src/test/java/org/xbill/DNS/TSIGTest.java index a83fb6e0..416e8747 100644 --- a/src/test/java/org/xbill/DNS/TSIGTest.java +++ b/src/test/java/org/xbill/DNS/TSIGTest.java @@ -2,6 +2,7 @@ package org.xbill.DNS; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -520,6 +521,19 @@ TCPClient createTcpClient(Duration timeout) throws IOException { assertEquals(202, handler.getRecords().size()); } + @Test + void invalidAdditionalCount() { + Message q = Message.newQuery(Record.newRecord(Name.root, Type.A, DClass.IN)); + Message m = new Message(); + m.addRecord(Record.newRecord(Name.root, Type.A, DClass.IN), Section.QUESTION); + m.addRecord(Record.newRecord(Name.root, Type.A, DClass.IN), Section.ANSWER); + m.addRecord( + Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN), + Section.ADDITIONAL); + assertDoesNotThrow(m::getTSIG); + assertDoesNotThrow(() -> m.normalize(q).getTSIG()); + } + @Getter private static class ZoneBuilderAxfrHandler implements ZoneTransferIn.ZoneTransferHandler { private final List records = new ArrayList<>(); diff --git a/src/test/java/org/xbill/DNS/dnssec/TestBase.java b/src/test/java/org/xbill/DNS/dnssec/TestBase.java index 60f53b2c..a5745c78 100644 --- a/src/test/java/org/xbill/DNS/dnssec/TestBase.java +++ b/src/test/java/org/xbill/DNS/dnssec/TestBase.java @@ -1,6 +1,7 @@ // SPDX-License-Identifier: BSD-3-Clause package org.xbill.DNS.dnssec; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.mock; @@ -126,7 +127,18 @@ private void starting(TestInfo description) { Message m; while ((m = messageReader.readMessage(r)) != null) { + for (int i = 0; i < 4; i++) { + assertThat(m.getHeader().getCount(i)) + .withFailMessage("Before normalization") + .isEqualTo(m.getSection(i).size()); + } + m = m.normalize(Message.newQuery(m.getQuestion()), true); + for (int i = 0; i < 4; i++) { + assertThat(m.getHeader().getCount(i)) + .withFailMessage("After normalization") + .isEqualTo(m.getSection(i).size()); + } queryResponsePairs.put(key(m), m); } @@ -286,7 +298,7 @@ protected String getEdeText(Message m) { .flatMap( opt -> opt.getOptions(Code.EDNS_EXTENDED_ERROR).stream() - .filter(o -> o instanceof ExtendedErrorCodeOption) + .filter(ExtendedErrorCodeOption.class::isInstance) .findFirst() .map(o -> ((ExtendedErrorCodeOption) o).getText())) .orElse(null); diff --git a/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java b/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java index 921e0beb..20700a9b 100644 --- a/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java +++ b/src/test/java/org/xbill/DNS/hosts/HostsFileParserTest.java @@ -5,6 +5,8 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; import java.io.BufferedWriter; import java.io.IOException; @@ -18,6 +20,9 @@ import java.nio.file.StandardCopyOption; import java.nio.file.StandardOpenOption; import java.nio.file.attribute.FileTime; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.Optional; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -72,7 +77,7 @@ void testMissingFileIsEmptyResult() throws IOException { } @Test - void testCacheLookup() throws IOException { + void testCacheLookupAfterFileDeleteWithoutChangeChecking() throws IOException { Path tempHosts = Files.copy(hostsFileWindows, tempDir, StandardCopyOption.REPLACE_EXISTING); HostsFileParser hostsFileParser = new HostsFileParser(tempHosts, false); assertEquals(0, hostsFileParser.cacheSize()); @@ -98,6 +103,10 @@ void testFileDeletionClearsCache() throws IOException { tempDir.resolve("testFileWatcherClearsCache"), StandardCopyOption.REPLACE_EXISTING); HostsFileParser hostsFileParser = new HostsFileParser(tempHosts); + Clock clock = mock(Clock.class); + hostsFileParser.setClock(clock); + Instant now = Clock.systemUTC().instant(); + when(clock.instant()).thenReturn(now); assertEquals(0, hostsFileParser.cacheSize()); assertEquals( kubernetesAddress, @@ -106,6 +115,7 @@ void testFileDeletionClearsCache() throws IOException { .orElseThrow(() -> new IllegalStateException("Host entry not found"))); assertTrue(hostsFileParser.cacheSize() > 1, "Cache must not be empty"); Files.delete(tempHosts); + when(clock.instant()).thenReturn(now.plus(Duration.ofMinutes(6))); assertEquals(Optional.empty(), hostsFileParser.getAddressForHost(kubernetesName, Type.A)); assertEquals(0, hostsFileParser.cacheSize()); } @@ -119,6 +129,10 @@ void testFileChangeClearsCache() throws IOException { StandardCopyOption.REPLACE_EXISTING); Files.setLastModifiedTime(tempHosts, FileTime.fromMillis(0)); HostsFileParser hostsFileParser = new HostsFileParser(tempHosts); + Clock clock = mock(Clock.class); + hostsFileParser.setClock(clock); + Instant now = Clock.systemUTC().instant(); + when(clock.instant()).thenReturn(now); assertEquals(0, hostsFileParser.cacheSize()); assertEquals( kubernetesAddress, @@ -134,6 +148,7 @@ void testFileChangeClearsCache() throws IOException { } Files.setLastModifiedTime(tempHosts, FileTime.fromMillis(10_0000)); + when(clock.instant()).thenReturn(now.plus(Duration.ofMinutes(6))); assertEquals( InetAddress.getByAddress(testName.toString(), localhostBytes), hostsFileParser