diff --git a/flight/flight-sql-jdbc-core/pom.xml b/flight/flight-sql-jdbc-core/pom.xml
index 237d450eec..02fef3dc1f 100644
--- a/flight/flight-sql-jdbc-core/pom.xml
+++ b/flight/flight-sql-jdbc-core/pom.xml
@@ -47,6 +47,21 @@ under the License.
+
+ io.grpc
+ grpc-api
+
+
+
+ io.grpc
+ grpc-netty
+
+
+
+ io.netty
+ netty-transport
+
+
org.apache.arrow
arrow-memory-core
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
index c1b1c8f8e6..cf9804d68b 100644
--- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
@@ -113,6 +113,7 @@ private static ArrowFlightSqlClientHandler createNewClientHandler(
.withRetainCookies(config.retainCookies())
.withRetainAuth(config.retainAuth())
.withCatalog(config.getCatalog())
+ .withConnectTimeout(config.getConnectTimeout())
.build();
} catch (final SQLException e) {
try {
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
index 0e9c79a090..cbbe223eb8 100644
--- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java
@@ -17,10 +17,13 @@
package org.apache.arrow.driver.jdbc.client;
import com.google.common.collect.ImmutableMap;
+import io.grpc.netty.NettyChannelBuilder;
+import io.netty.channel.ChannelOption;
import java.io.IOException;
import java.net.URI;
import java.security.GeneralSecurityException;
import java.sql.SQLException;
+import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -36,6 +39,7 @@
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightClientMiddleware;
import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightGrpcUtils;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightRuntimeException;
import org.apache.arrow.flight.FlightStatusCode;
@@ -50,6 +54,7 @@
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.client.ClientCookieMiddleware;
import org.apache.arrow.flight.grpc.CredentialCallOption;
+import org.apache.arrow.flight.grpc.NettyClientBuilder;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo;
import org.apache.arrow.flight.sql.util.TableRef;
@@ -138,12 +143,11 @@ public List getStreams(final FlightInfo flightInfo)
// Clone the builder and then set the new endpoint on it.
// GH-38574: Currently a new FlightClient will be made for each partition that returns a
- // non-empty Location
- // then disposed of. It may be better to cache clients because a server may report the
- // same Locations.
- // It would also be good to identify when the reported location is the same as the
- // original connection's
- // Location and skip creating a FlightClient in that scenario.
+ // non-empty Location then disposed of. It may be better to cache clients because a server
+ // may report the same Locations. It would also be good to identify when the reported
+ // location
+ // is the same as the original connection's Location and skip creating a FlightClient in
+ // that scenario.
List exceptions = new ArrayList<>();
CloseableEndpointStreamPair stream = null;
for (Location location : endpoint.getLocations()) {
@@ -158,7 +162,8 @@ public List getStreams(final FlightInfo flightInfo)
new Builder(ArrowFlightSqlClientHandler.this.builder)
.withHost(endpointUri.getHost())
.withPort(endpointUri.getPort())
- .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS));
+ .withEncryption(endpointUri.getScheme().equals(LocationSchemes.GRPC_TLS))
+ .withConnectTimeout(builder.connectTimeout);
ArrowFlightSqlClientHandler endpointHandler = null;
try {
@@ -177,6 +182,7 @@ public List getStreams(final FlightInfo flightInfo)
exceptions.add(ex);
continue;
}
+
break;
}
if (stream != null) {
@@ -543,6 +549,8 @@ public static final class Builder {
@VisibleForTesting Optional catalog = Optional.empty();
+ @VisibleForTesting @Nullable Duration connectTimeout;
+
// These two middleware are for internal use within build() and should not be exposed by builder
// APIs.
// Note that these middleware may not necessarily be registered.
@@ -825,6 +833,19 @@ public Builder withCatalog(@Nullable final String catalog) {
return this;
}
+ public Builder withConnectTimeout(Duration connectTimeout) {
+ this.connectTimeout = connectTimeout;
+ return this;
+ }
+
+ /** Get the location that this client will connect to. */
+ public Location getLocation() {
+ if (useEncryption) {
+ return Location.forGrpcTls(host, port);
+ }
+ return Location.forGrpcInsecure(host, port);
+ }
+
/**
* Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields.
*
@@ -845,17 +866,15 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
if (isUsingUserPasswordAuth) {
buildTimeMiddlewareFactories.add(authFactory);
}
- final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator);
+ final NettyClientBuilder clientBuilder = new NettyClientBuilder();
+ clientBuilder.allocator(allocator);
buildTimeMiddlewareFactories.add(new ClientCookieMiddleware.Factory());
buildTimeMiddlewareFactories.forEach(clientBuilder::intercept);
- Location location;
if (useEncryption) {
- location = Location.forGrpcTls(host, port);
clientBuilder.useTls();
- } else {
- location = Location.forGrpcInsecure(host, port);
}
+ Location location = getLocation();
clientBuilder.location(location);
if (useEncryption) {
@@ -883,7 +902,14 @@ public ArrowFlightSqlClientHandler build() throws SQLException {
}
}
- client = clientBuilder.build();
+ NettyChannelBuilder channelBuilder = clientBuilder.build();
+ if (connectTimeout != null) {
+ channelBuilder.withOption(
+ ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) connectTimeout.toMillis());
+ }
+ client =
+ FlightGrpcUtils.createFlightClient(
+ allocator, channelBuilder.build(), clientBuilder.middleware());
final ArrayList credentialOptions = new ArrayList<>();
if (isUsingUserPasswordAuth) {
// If the authFactory has already been used for a handshake, use the existing token.
diff --git a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java
index e8bae2a207..ab6a5898b7 100644
--- a/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java
+++ b/flight/flight-sql-jdbc-core/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java
@@ -16,6 +16,7 @@
*/
package org.apache.arrow.driver.jdbc.utils;
+import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
@@ -163,6 +164,16 @@ public String getCatalog() {
return ArrowFlightConnectionProperty.CATALOG.getString(properties);
}
+ /** The initial connect timeout. */
+ public Duration getConnectTimeout() {
+ Integer timeout = ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.getInteger(properties);
+ if (timeout == null) {
+ return Duration.ofMillis(
+ (int) ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS.defaultValue());
+ }
+ return Duration.ofMillis(timeout);
+ }
+
/**
* Gets the {@link CallOption}s from this {@link ConnectionConfig}.
*
@@ -213,7 +224,9 @@ public enum ArrowFlightConnectionProperty implements ConnectionProperty {
TOKEN("token", null, Type.STRING, false),
RETAIN_COOKIES("retainCookies", true, Type.BOOLEAN, false),
RETAIN_AUTH("retainAuth", true, Type.BOOLEAN, false),
- CATALOG("catalog", null, Type.STRING, false);
+ CATALOG("catalog", null, Type.STRING, false),
+ CONNECT_TIMEOUT_MILLIS("connectTimeoutMs", 10000, Type.NUMBER, false),
+ ;
private final String camelName;
private final Object defaultValue;
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
index a8d04dfc83..cd47408f52 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java
@@ -25,12 +25,7 @@
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-import static org.junit.jupiter.api.Assertions.fail;
+import static org.junit.jupiter.api.Assertions.*;
import com.google.common.collect.ImmutableSet;
import java.nio.charset.StandardCharsets;
@@ -645,6 +640,139 @@ public void testFallbackSecondFlightServer() throws Exception {
}
}
+ @Test
+ public void testFallbackUnresolvableFlightServer() throws Exception {
+ final Schema schema =
+ new Schema(
+ Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
+ resultData.setRowCount(1);
+ ((IntVector) resultData.getVector(0)).set(0, 1);
+
+ try (final FallbackFlightSqlProducer rootProducer =
+ new FallbackFlightSqlProducer(resultData);
+ FlightServer rootServer =
+ FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer)
+ .build()
+ .start();
+ Connection newConnection =
+ DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false",
+ rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) {
+ // This first attempt should take a measurable amount of time.
+ long start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ long attempt1 = System.nanoTime();
+ double elapsedMs = (attempt1 - start) / 1_000_000.;
+ assertTrue(
+ elapsedMs >= 5000.,
+ String.format(
+ "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs));
+
+ // Once the client cache is implemented (GH-661), this second attempt should take less time,
+ // since the failure from before should be cached.
+ start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ attempt1 = System.nanoTime();
+ elapsedMs = (attempt1 - start) / 1_000_000.;
+ // TODO(GH-661): this assertion should be flipped to assertTrue.
+ assertFalse(
+ elapsedMs < 5000.,
+ String.format("Expected second attempt to be the same, but %f ms elapsed", elapsedMs));
+ }
+ }
+ }
+
+ @Test
+ public void testFallbackUnresolvableFlightServerDisableCache() throws Exception {
+ final Schema schema =
+ new Schema(
+ Collections.singletonList(Field.nullable("int_column", Types.MinorType.INT.getType())));
+ try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ VectorSchemaRoot resultData = VectorSchemaRoot.create(schema, allocator)) {
+ resultData.setRowCount(1);
+ ((IntVector) resultData.getVector(0)).set(0, 1);
+
+ try (final FallbackFlightSqlProducer rootProducer =
+ new FallbackFlightSqlProducer(resultData);
+ FlightServer rootServer =
+ FlightServer.builder(allocator, forGrpcInsecure("localhost", 0), rootProducer)
+ .build()
+ .start();
+ Connection newConnection =
+ DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://%s:%d/?useEncryption=false&useClientCache=false",
+ rootServer.getLocation().getUri().getHost(), rootServer.getPort()))) {
+ // This first attempt should take a measurable amount of time.
+ long start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ long attempt1 = System.nanoTime();
+ double elapsedMs = (attempt1 - start) / 1_000_000.;
+ assertTrue(
+ elapsedMs >= 5000.,
+ String.format(
+ "Expected first attempt to hit the timeout, but only %f ms elapsed", elapsedMs));
+
+ // This second attempt should take a long time still, since we disabled the cache.
+ start = System.nanoTime();
+ try (Statement newStatement = newConnection.createStatement()) {
+ try (ResultSet result = newStatement.executeQuery("fallback with unresolvable")) {
+ List actualData = new ArrayList<>();
+ while (result.next()) {
+ actualData.add(result.getInt(1));
+ }
+
+ // Assert
+ assertEquals(resultData.getRowCount(), actualData.size());
+ assertTrue(actualData.contains(((IntVector) resultData.getVector(0)).get(0)));
+ }
+ }
+ attempt1 = System.nanoTime();
+ elapsedMs = (attempt1 - start) / 1_000_000.;
+ assertTrue(
+ elapsedMs >= 5000.,
+ String.format(
+ "Expected second attempt to hit the timeout, but only %f ms elapsed", elapsedMs));
+ }
+ }
+ }
+
@Test
public void testShouldRunSelectQueryWithEmptyVectorsEmbedded() throws Exception {
try (Statement statement = connection.createStatement();
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java
index 6beaba8236..7b416638e1 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandlerBuilderTest.java
@@ -147,6 +147,7 @@ public void testDefaults() {
assertNull(builder.clientCertificatePath);
assertNull(builder.clientKeyPath);
assertEquals(Optional.empty(), builder.catalog);
+ assertNull(builder.connectTimeout);
}
@Test
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
index 4a46b5f5be..c780d53fab 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
@@ -18,6 +18,7 @@
import static java.lang.Runtime.getRuntime;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CATALOG;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.CONNECT_TIMEOUT_MILLIS;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD;
import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT;
@@ -27,6 +28,7 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
+import java.time.Duration;
import java.util.Properties;
import java.util.Random;
import java.util.function.Function;
@@ -59,49 +61,67 @@ public void setUp() {
public void testGetProperty(
ArrowFlightConnectionProperty property,
Object value,
+ Object expected,
Function configFunction) {
properties.put(property.camelName(), value);
arrowFlightConnectionConfigFunction = configFunction;
- assertThat(configFunction.apply(arrowFlightConnectionConfig), is(value));
- assertThat(arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(value));
+ assertThat(configFunction.apply(arrowFlightConnectionConfig), is(expected));
+ assertThat(
+ arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig), is(expected));
}
public static Stream provideParameters() {
+ int port = RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE));
+ boolean useEncryption = RANDOM.nextBoolean();
+ int threadPoolSize = RANDOM.nextInt(getRuntime().availableProcessors());
return Stream.of(
Arguments.of(
HOST,
"host",
+ "host",
(Function)
ArrowFlightConnectionConfigImpl::getHost),
Arguments.of(
PORT,
- RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)),
+ port,
+ port,
(Function)
ArrowFlightConnectionConfigImpl::getPort),
Arguments.of(
USER,
"user",
+ "user",
(Function)
ArrowFlightConnectionConfigImpl::getUser),
Arguments.of(
PASSWORD,
"password",
+ "password",
(Function)
ArrowFlightConnectionConfigImpl::getPassword),
Arguments.of(
USE_ENCRYPTION,
- RANDOM.nextBoolean(),
+ useEncryption,
+ useEncryption,
(Function)
ArrowFlightConnectionConfigImpl::useEncryption),
Arguments.of(
THREAD_POOL_SIZE,
- RANDOM.nextInt(getRuntime().availableProcessors()),
+ threadPoolSize,
+ threadPoolSize,
(Function)
ArrowFlightConnectionConfigImpl::threadPoolSize),
Arguments.of(
CATALOG,
"catalog",
+ "catalog",
+ (Function)
+ ArrowFlightConnectionConfigImpl::getCatalog),
+ Arguments.of(
+ CONNECT_TIMEOUT_MILLIS,
+ 5000,
+ Duration.ofMillis(5000),
(Function)
- ArrowFlightConnectionConfigImpl::getCatalog));
+ ArrowFlightConnectionConfigImpl::getConnectTimeout));
}
}
diff --git a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
index 9aa257172c..670b9e3be0 100644
--- a/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
+++ b/flight/flight-sql-jdbc-core/src/test/java/org/apache/arrow/driver/jdbc/utils/FallbackFlightSqlProducer.java
@@ -109,6 +109,16 @@ private FlightInfo getFlightInfo(FlightDescriptor descriptor, String query) {
Location.forGrpcInsecure("localhost", 9999),
Location.reuseConnection())
.build());
+ } else if (query.equals("fallback with unresolvable")) {
+ endpoints =
+ Collections.singletonList(
+ FlightEndpoint.builder(
+ ticket,
+ // Inaccessible IP
+ // https://stackoverflow.com/questions/10456044/what-is-a-good-invalid-ip-address-to-use-for-unit-tests
+ Location.forGrpcInsecure("203.0.113.0", 9999),
+ Location.reuseConnection())
+ .build());
} else {
throw CallStatus.UNIMPLEMENTED.withDescription(query).toRuntimeException();
}