diff --git a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java index 93c062f13f..51ed12210f 100644 --- a/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java +++ b/driver-sync/src/test/functional/com/mongodb/internal/connection/OidcAuthenticationProseTests.java @@ -16,6 +16,7 @@ package com.mongodb.internal.connection; +import com.mongodb.ClusterFixture; import com.mongodb.ConnectionString; import com.mongodb.MongoClientSettings; import com.mongodb.MongoCommandException; @@ -41,11 +42,11 @@ import org.bson.Document; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.junit.jupiter.params.provider.ValueSource; import java.io.IOException; import java.lang.reflect.Field; @@ -79,7 +80,6 @@ import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY; import static com.mongodb.assertions.Assertions.assertNotNull; import static com.mongodb.testing.MongoAssertions.assertCause; -import static java.lang.Math.min; import static java.lang.String.format; import static java.lang.System.getenv; import static java.util.Arrays.asList; @@ -215,9 +215,9 @@ public void test2p1ValidCallbackInputs() { + " expectedTimeoutThreshold={3}") @MethodSource void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName, - final int timeoutMs, - final int serverSelectionTimeoutMS, - final int expectedTimeoutThreshold) { + final long timeoutMs, + final long serverSelectionTimeoutMS, + final long expectedTimeoutThreshold) { TestCallback callback1 = createCallback(); OidcCallback callback2 = (context) -> { @@ -242,40 +242,50 @@ void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName, assertEquals(1, callback1.getInvocations()); long elapsed = msElapsedSince(start); - assertFalse(elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min(serverSelectionTimeoutMS, timeoutMs)), + + assertFalse(elapsed > minTimeout(timeoutMs, serverSelectionTimeoutMS), format("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. " + "This indicates that the callback was not called with the expected timeout.", - min(serverSelectionTimeoutMS, timeoutMs), - elapsed)); + elapsed, + minTimeout(timeoutMs, serverSelectionTimeoutMS))); + } } private static Stream testValidCallbackInputsTimeoutWhenTimeoutMsIsSet() { + long rtt = ClusterFixture.getPrimaryRTT(); return Stream.of( Arguments.of("serverSelectionTimeoutMS honored for oidc callback if it's lower than timeoutMS", - 1000, // timeoutMS - 500, // serverSelectionTimeoutMS - 499), // expectedTimeoutThreshold + 1000 + rtt, // timeoutMS + 500 + rtt, // serverSelectionTimeoutMS + 499 + rtt), // expectedTimeoutThreshold Arguments.of("timeoutMS honored for oidc callback if it's lower than serverSelectionTimeoutMS", - 500, // timeoutMS - 1000, // serverSelectionTimeoutMS - 499), // expectedTimeoutThreshold + 500 + rtt, // timeoutMS + 1000 + rtt, // serverSelectionTimeoutMS + 499 + rtt), // expectedTimeoutThreshold + Arguments.of("timeoutMS honored for oidc callback if serverSelectionTimeoutMS is infinite", + 500 + rtt, // timeoutMS + -1, // serverSelectionTimeoutMS + 499 + rtt), // expectedTimeoutThreshold, Arguments.of("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0", 0, // infinite timeoutMS - 500, // serverSelectionTimeoutMS - 499) // expectedTimeoutThreshold + 500 + rtt, // serverSelectionTimeoutMS + 499 + rtt) // expectedTimeoutThreshold ); } // Not a prose test - @ParameterizedTest(name = "test callback timeout when server selection timeout is " - + "infinite and timeoutMs is set to {0}") - @ValueSource(ints = {0, 100}) - void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final int timeoutMs) { + @Test + @DisplayName("test callback timeout when serverSelectionTimeoutMS and timeoutMS are infinite") + void testCallbackTimeoutWhenServerSelectionTimeoutMsIsInfiniteTimeoutMsIsSet() { TestCallback callback1 = createCallback(); + Duration expectedTimeout = ChronoUnit.FOREVER.getDuration(); OidcCallback callback2 = (context) -> { - assertEquals(context.getTimeout(), ChronoUnit.FOREVER.getDuration()); + assertEquals(expectedTimeout, context.getTimeout(), + format("Expected timeout to be infinite (%s), but was %s", + expectedTimeout, context.getTimeout())); + return callback1.onRequest(context); }; @@ -284,7 +294,7 @@ void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final builder.serverSelectionTimeout( -1, // -1 means infinite TimeUnit.MILLISECONDS)) - .timeout(timeoutMs, TimeUnit.MILLISECONDS) + .timeout(0, TimeUnit.MILLISECONDS) .build(); try (MongoClient mongoClient = createMongoClient(clientSettings)) { @@ -1242,4 +1252,10 @@ public TestCallback createHumanCallback() { private long msElapsedSince(final long timeOfStart) { return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - timeOfStart); } + + private static long minTimeout(final long timeoutMs, final long serverSelectionTimeoutMS) { + long timeoutMsEffective = timeoutMs != 0 ? timeoutMs : Long.MAX_VALUE; + long serverSelectionTimeoutMSEffective = serverSelectionTimeoutMS != -1 ? serverSelectionTimeoutMS : Long.MAX_VALUE; + return Math.min(timeoutMsEffective, serverSelectionTimeoutMSEffective); + } }