diff --git a/src/main/java/com/flagsmith/config/FlagsmithConfig.java b/src/main/java/com/flagsmith/config/FlagsmithConfig.java index ffb7357f..33d8cd1f 100644 --- a/src/main/java/com/flagsmith/config/FlagsmithConfig.java +++ b/src/main/java/com/flagsmith/config/FlagsmithConfig.java @@ -11,6 +11,7 @@ import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.X509TrustManager; import lombok.Getter; +import okhttp3.ConnectionPool; import okhttp3.HttpUrl; import okhttp3.Interceptor; import okhttp3.OkHttpClient; @@ -64,6 +65,9 @@ protected FlagsmithConfig(Builder builder) { if (builder.proxy != null) { httpBuilder.proxy(builder.proxy); } + if (builder.connectionPool != null) { + httpBuilder.connectionPool(builder.connectionPool); + } if (!builder.supportedProtocols.isEmpty()) { httpBuilder.protocols( builder.supportedProtocols.stream() @@ -110,6 +114,7 @@ public static class Builder { private X509TrustManager trustManager; private FlagsmithFlagDefaults flagsmithFlagDefaults; private AnalyticsProcessor analyticsProcessor; + private ConnectionPool connectionPool; private Boolean enableLocalEvaluation = Boolean.FALSE; private Integer environmentRefreshIntervalSeconds = DEFAULT_ENVIRONMENT_REFRESH_SECONDS; @@ -203,6 +208,18 @@ public Builder withProxy(Proxy proxy) { return this; } + /** + * Provide a custom OkHttp ConnectionPool to tune keep-alive duration and maximum idle + * connections. When not set, OkHttp's defaults are used. + * + * @param connectionPool the ConnectionPool to use for the HTTP client + * @return the Builder + */ + public Builder connectionPool(ConnectionPool connectionPool) { + this.connectionPool = connectionPool; + return this; + } + /** * Add retries for HTTP request to the builder. * diff --git a/src/test/java/com/flagsmith/config/FlagsmithConfigTest.java b/src/test/java/com/flagsmith/config/FlagsmithConfigTest.java index 12c4a7d3..a2524964 100644 --- a/src/test/java/com/flagsmith/config/FlagsmithConfigTest.java +++ b/src/test/java/com/flagsmith/config/FlagsmithConfigTest.java @@ -2,12 +2,15 @@ import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import com.flagsmith.config.FlagsmithConfig.Protocol; import java.net.InetSocketAddress; import java.net.Proxy; import java.util.Collections; +import java.util.concurrent.TimeUnit; +import okhttp3.ConnectionPool; import okhttp3.mock.MockInterceptor; import org.junit.jupiter.api.Test; @@ -53,6 +56,29 @@ public void configTest_multipleInterceptors() { assertEquals(2, flagsmithConfig.getHttpClient().interceptors().size()); } + @Test + public void configTest_customConnectionPool_respectsKeepAliveAndMaxIdle() throws Exception { + final ConnectionPool pool = new ConnectionPool(7, 42, TimeUnit.SECONDS); + + final FlagsmithConfig flagsmithConfig = FlagsmithConfig.newBuilder() + .connectionPool(pool) + .build(); + + final ConnectionPool wired = flagsmithConfig.getHttpClient().connectionPool(); + final Object delegate = readField(wired, "delegate"); + assertEquals(7, readInt(delegate, "maxIdleConnections")); + assertEquals(TimeUnit.SECONDS.toNanos(42), readLong(delegate, "keepAliveDurationNs")); + } + + @Test + public void configTest_nullConnectionPool_isSafeAndUsesDefault() { + final FlagsmithConfig flagsmithConfig = FlagsmithConfig.newBuilder() + .connectionPool(null) + .build(); + + assertNotNull(flagsmithConfig.getHttpClient().connectionPool()); + } + @Test public void configTest_supportedProtocols() { final FlagsmithConfig defaultFlagsmithConfig = FlagsmithConfig.newBuilder().build(); @@ -65,4 +91,18 @@ public void configTest_supportedProtocols() { assertEquals(1, customFlagsmithConfig.getHttpClient().protocols().size()); assertEquals(okhttp3.Protocol.HTTP_1_1, customFlagsmithConfig.getHttpClient().protocols().get(0)); } + + private static Object readField(Object target, String fieldName) throws Exception { + java.lang.reflect.Field field = target.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + return field.get(target); + } + + private static int readInt(Object target, String fieldName) throws Exception { + return (int) readField(target, fieldName); + } + + private static long readLong(Object target, String fieldName) throws Exception { + return (long) readField(target, fieldName); + } }