Skip to content

Fix OIDC tests. #1753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.mongodb.internal.connection;

import com.mongodb.ClusterFixture;
import com.mongodb.ConnectionString;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoCommandException;
Expand All @@ -41,11 +42,11 @@
import org.bson.Document;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;

import java.io.IOException;
import java.lang.reflect.Field;
Expand Down Expand Up @@ -79,7 +80,6 @@
import static com.mongodb.MongoCredential.TOKEN_RESOURCE_KEY;
import static com.mongodb.assertions.Assertions.assertNotNull;
import static com.mongodb.testing.MongoAssertions.assertCause;
import static java.lang.Math.min;
import static java.lang.String.format;
import static java.lang.System.getenv;
import static java.util.Arrays.asList;
Expand Down Expand Up @@ -215,9 +215,9 @@ public void test2p1ValidCallbackInputs() {
+ " expectedTimeoutThreshold={3}")
@MethodSource
void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName,
final int timeoutMs,
final int serverSelectionTimeoutMS,
final int expectedTimeoutThreshold) {
final long timeoutMs,
final long serverSelectionTimeoutMS,
final long expectedTimeoutThreshold) {
TestCallback callback1 = createCallback();

OidcCallback callback2 = (context) -> {
Expand All @@ -242,40 +242,50 @@ void testValidCallbackInputsTimeoutWhenTimeoutMsIsSet(final String testName,
assertEquals(1, callback1.getInvocations());
long elapsed = msElapsedSince(start);

assertFalse(elapsed > (timeoutMs == 0 ? serverSelectionTimeoutMS : min(serverSelectionTimeoutMS, timeoutMs)),

assertFalse(elapsed > minTimeout(timeoutMs, serverSelectionTimeoutMS),
format("Elapsed time %d is greater then minimum of serverSelectionTimeoutMS and timeoutMs, which is %d. "
+ "This indicates that the callback was not called with the expected timeout.",
min(serverSelectionTimeoutMS, timeoutMs),
elapsed));
elapsed,
minTimeout(timeoutMs, serverSelectionTimeoutMS)));

}
}

private static Stream<Arguments> testValidCallbackInputsTimeoutWhenTimeoutMsIsSet() {
long rtt = ClusterFixture.getPrimaryRTT();
return Stream.of(
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if it's lower than timeoutMS",
1000, // timeoutMS
500, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold
1000 + rtt, // timeoutMS
500 + rtt, // serverSelectionTimeoutMS
499 + rtt), // expectedTimeoutThreshold
Arguments.of("timeoutMS honored for oidc callback if it's lower than serverSelectionTimeoutMS",
500, // timeoutMS
1000, // serverSelectionTimeoutMS
499), // expectedTimeoutThreshold
500 + rtt, // timeoutMS
1000 + rtt, // serverSelectionTimeoutMS
499 + rtt), // expectedTimeoutThreshold
Arguments.of("timeoutMS honored for oidc callback if serverSelectionTimeoutMS is infinite",
500 + rtt, // timeoutMS
-1, // serverSelectionTimeoutMS
499 + rtt), // expectedTimeoutThreshold,
Arguments.of("serverSelectionTimeoutMS honored for oidc callback if timeoutMS=0",
0, // infinite timeoutMS
500, // serverSelectionTimeoutMS
499) // expectedTimeoutThreshold
500 + rtt, // serverSelectionTimeoutMS
499 + rtt) // expectedTimeoutThreshold
);
}

// Not a prose test
@ParameterizedTest(name = "test callback timeout when server selection timeout is "
+ "infinite and timeoutMs is set to {0}")
@ValueSource(ints = {0, 100})
void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final int timeoutMs) {
@Test
@DisplayName("test callback timeout when serverSelectionTimeoutMS and timeoutMS are infinite")
void testCallbackTimeoutWhenServerSelectionTimeoutMsIsInfiniteTimeoutMsIsSet() {
TestCallback callback1 = createCallback();
Duration expectedTimeout = ChronoUnit.FOREVER.getDuration();

OidcCallback callback2 = (context) -> {
assertEquals(context.getTimeout(), ChronoUnit.FOREVER.getDuration());
assertEquals(expectedTimeout, context.getTimeout(),
format("Expected timeout to be infinite (%s), but was %s",
expectedTimeout, context.getTimeout()));

return callback1.onRequest(context);
};

Expand All @@ -284,7 +294,7 @@ void testCallbackTimeoutWhenServerSelectionTimeoutIsInfiniteTimeoutMsIsSet(final
builder.serverSelectionTimeout(
-1, // -1 means infinite
TimeUnit.MILLISECONDS))
.timeout(timeoutMs, TimeUnit.MILLISECONDS)
.timeout(0, TimeUnit.MILLISECONDS)
.build();

try (MongoClient mongoClient = createMongoClient(clientSettings)) {
Expand Down Expand Up @@ -1242,4 +1252,10 @@ public TestCallback createHumanCallback() {
private long msElapsedSince(final long timeOfStart) {
return TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - timeOfStart);
}

private static long minTimeout(final long timeoutMs, final long serverSelectionTimeoutMS) {
long timeoutMsEffective = timeoutMs != 0 ? timeoutMs : Long.MAX_VALUE;
long serverSelectionTimeoutMSEffective = serverSelectionTimeoutMS != -1 ? serverSelectionTimeoutMS : Long.MAX_VALUE;
return Math.min(timeoutMsEffective, serverSelectionTimeoutMSEffective);
}
}