Skip to content

Commit 61a3b1c

Browse files
authoredApr 12, 2024··
SNOW-916942: Fix native okta retry logic (#1701)
1 parent e5c4cbd commit 61a3b1c

File tree

5 files changed

+182
-17
lines changed

5 files changed

+182
-17
lines changed
 

‎src/main/java/net/snowflake/client/core/SessionUtil.java

+45-17
Original file line numberDiff line numberDiff line change
@@ -651,16 +651,31 @@ private static SFLoginOutput newSession(
651651
loginInput.getHttpClientSettingsKey());
652652
} catch (SnowflakeSQLException ex) {
653653
if (ex.getErrorCode() == ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT.getMessageCode()) {
654-
if (authenticatorType == ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT) {
655-
SessionUtilKeyPair s =
656-
new SessionUtilKeyPair(
657-
loginInput.getPrivateKey(),
658-
loginInput.getPrivateKeyFile(),
659-
loginInput.getPrivateKeyFilePwd(),
660-
loginInput.getAccountName(),
661-
loginInput.getUserName());
662-
663-
data.put(ClientAuthnParameter.TOKEN.name(), s.issueJwtToken());
654+
if (authenticatorType == ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT
655+
|| authenticatorType == ClientAuthnDTO.AuthenticatorType.OKTA) {
656+
657+
if (authenticatorType == ClientAuthnDTO.AuthenticatorType.SNOWFLAKE_JWT) {
658+
SessionUtilKeyPair s =
659+
new SessionUtilKeyPair(
660+
loginInput.getPrivateKey(),
661+
loginInput.getPrivateKeyFile(),
662+
loginInput.getPrivateKeyFilePwd(),
663+
loginInput.getAccountName(),
664+
loginInput.getUserName());
665+
666+
data.put(ClientAuthnParameter.TOKEN.name(), s.issueJwtToken());
667+
} else if (authenticatorType == ClientAuthnDTO.AuthenticatorType.OKTA) {
668+
logger.debug("Retrieve new token for Okta authentication.");
669+
// If we need to retry, we need to get a new Okta token
670+
tokenOrSamlResponse = getSamlResponseUsingOkta(loginInput);
671+
data.put(ClientAuthnParameter.RAW_SAML_RESPONSE.name(), tokenOrSamlResponse);
672+
authnData.setData(data);
673+
String updatedJson = mapper.writeValueAsString(authnData);
674+
675+
StringEntity updatedInput = new StringEntity(updatedJson, StandardCharsets.UTF_8);
676+
updatedInput.setContentType("application/json");
677+
postRequest.setEntity(updatedInput);
678+
}
664679

665680
long elapsedSeconds = ex.getElapsedSeconds();
666681

@@ -690,7 +705,8 @@ private static SFLoginOutput newSession(
690705
}
691706
}
692707

693-
// JWT renew should not count as a retry, so we pass back the current retry count.
708+
// JWT or Okta renew should not count as a retry, so we pass back the current retry
709+
// count.
694710
retryCount = ex.getRetryCount();
695711

696712
continue;
@@ -1362,12 +1378,24 @@ private static void handleFederatedFlowError(SFLoginInput loginInput, Exception
13621378
*/
13631379
private static String getSamlResponseUsingOkta(SFLoginInput loginInput)
13641380
throws SnowflakeSQLException {
1365-
JsonNode dataNode = federatedFlowStep1(loginInput);
1366-
String tokenUrl = dataNode.path("tokenUrl").asText();
1367-
String ssoUrl = dataNode.path("ssoUrl").asText();
1368-
federatedFlowStep2(loginInput, tokenUrl, ssoUrl);
1369-
final String oneTimeToken = federatedFlowStep3(loginInput, tokenUrl);
1370-
return federatedFlowStep4(loginInput, ssoUrl, oneTimeToken);
1381+
while (true) {
1382+
try {
1383+
JsonNode dataNode = federatedFlowStep1(loginInput);
1384+
String tokenUrl = dataNode.path("tokenUrl").asText();
1385+
String ssoUrl = dataNode.path("ssoUrl").asText();
1386+
federatedFlowStep2(loginInput, tokenUrl, ssoUrl);
1387+
final String oneTimeToken = federatedFlowStep3(loginInput, tokenUrl);
1388+
return federatedFlowStep4(loginInput, ssoUrl, oneTimeToken);
1389+
} catch (SnowflakeSQLException ex) {
1390+
// This error gets thrown if the okta request encountered a retry-able error that
1391+
// requires getting a new one-time token.
1392+
if (ex.getErrorCode() == ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT.getMessageCode()) {
1393+
logger.debug("Failed to get Okta SAML response. Retrying without changing retry count.");
1394+
} else {
1395+
throw ex;
1396+
}
1397+
}
1398+
}
13711399
}
13721400

13731401
/**

‎src/main/java/net/snowflake/client/jdbc/RestRequest.java

+10
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,16 @@ public static CloseableHttpResponse execute(
399399
break;
400400
}
401401

402+
// If this was a request for an Okta one-time token that failed with a retry-able error,
403+
// throw exception to renew the token before trying again.
404+
if (String.valueOf(httpRequest.getURI()).contains("okta.com/api/v1/authn")) {
405+
throw new SnowflakeSQLException(
406+
ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT,
407+
retryCount,
408+
true,
409+
elapsedMilliForTransientIssues / 1000);
410+
}
411+
402412
// Make sure that any authenticator specific info that needs to be
403413
// updated get's updated before the next retry. Ex - JWT token
404414
// Check to see if customer set socket/connect timeout has been reached,

‎src/test/java/net/snowflake/client/core/SessionUtilLatestIT.java

+75
Original file line numberDiff line numberDiff line change
@@ -390,4 +390,79 @@ public void testOktaAuthGetFail() throws Throwable {
390390
assertEquals(SqlState.IO_ERROR, e.getSQLState());
391391
}
392392
}
393+
394+
private SFLoginInput createOktaLoginInput() {
395+
SFLoginInput input = new SFLoginInput();
396+
input.setServerUrl("https://testauth.okta.com");
397+
input.setUserName("MOCK_USERNAME");
398+
input.setPassword("MOCK_PASSWORD");
399+
input.setAccountName("MOCK_ACCOUNT_NAME");
400+
input.setAppId("MOCK_APP_ID");
401+
input.setOCSPMode(OCSPMode.FAIL_OPEN);
402+
input.setHttpClientSettingsKey(new HttpClientSettingsKey(OCSPMode.FAIL_OPEN));
403+
input.setLoginTimeout(1000);
404+
input.setSessionParameters(new HashMap<>());
405+
input.setAuthenticator("https://testauth.okta.com");
406+
return input;
407+
}
408+
409+
// Testing retry with Okta calls the service to get a new unique token. This is valid after
410+
// version 3.15.1.
411+
@Test
412+
public void testOktaAuthRetry() throws Throwable {
413+
SFLoginInput loginInput = createOktaLoginInput();
414+
Map<SFSessionProperty, Object> connectionPropertiesMap = initConnectionPropertiesMap();
415+
SnowflakeSQLException ex =
416+
new SnowflakeSQLException(ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT, 0, true, 0);
417+
try (MockedStatic<HttpUtil> mockedHttpUtil = mockStatic(HttpUtil.class)) {
418+
mockedHttpUtil
419+
.when(
420+
() ->
421+
HttpUtil.executeGeneralRequest(
422+
Mockito.any(HttpPost.class),
423+
Mockito.anyInt(),
424+
Mockito.anyInt(),
425+
Mockito.anyInt(),
426+
Mockito.anyInt(),
427+
Mockito.nullable(HttpClientSettingsKey.class)))
428+
.thenReturn(
429+
"{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\","
430+
+ "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\","
431+
+ "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}")
432+
.thenThrow(ex)
433+
.thenReturn(
434+
"{\"data\":{\"tokenUrl\":\"https://testauth.okta.com/api/v1/authn\","
435+
+ "\"ssoUrl\":\"https://testauth.okta.com/app/snowflake/abcdefghijklmnopqrstuvwxyz/sso/saml\","
436+
+ "\"proofKey\":null},\"code\":null,\"message\":null,\"success\":true}");
437+
438+
mockedHttpUtil
439+
.when(
440+
() ->
441+
HttpUtil.executeRequestWithoutCookies(
442+
Mockito.any(HttpRequestBase.class),
443+
Mockito.anyInt(),
444+
Mockito.anyInt(),
445+
Mockito.anyInt(),
446+
Mockito.anyInt(),
447+
Mockito.anyInt(),
448+
Mockito.nullable(AtomicBoolean.class),
449+
Mockito.nullable(HttpClientSettingsKey.class)))
450+
.thenReturn(
451+
"{\"expiresAt\":\"2023-10-13T19:18:09.000Z\",\"status\":\"SUCCESS\",\"sessionToken\":\"testsessiontoken\"}");
452+
453+
mockedHttpUtil
454+
.when(
455+
() ->
456+
HttpUtil.executeGeneralRequest(
457+
Mockito.any(HttpGet.class),
458+
Mockito.anyInt(),
459+
Mockito.anyInt(),
460+
Mockito.anyInt(),
461+
Mockito.anyInt(),
462+
Mockito.nullable(HttpClientSettingsKey.class)))
463+
.thenReturn("<body><form action=\"https://testauth.okta.com\"></form></body>");
464+
465+
SessionUtil.openSession(loginInput, connectionPropertiesMap, "ALL");
466+
}
467+
}
393468
}

‎src/test/java/net/snowflake/client/jdbc/ConnectionLatestIT.java

+38
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,10 @@
3838
import java.sql.SQLException;
3939
import java.sql.Statement;
4040
import java.time.Duration;
41+
import java.util.ArrayList;
4142
import java.util.Collections;
4243
import java.util.Enumeration;
44+
import java.util.List;
4345
import java.util.Map;
4446
import java.util.Properties;
4547
import java.util.concurrent.TimeUnit;
@@ -62,6 +64,7 @@
6264
import org.apache.http.entity.StringEntity;
6365
import org.junit.After;
6466
import org.junit.Before;
67+
import org.junit.Ignore;
6568
import org.junit.Rule;
6669
import org.junit.Test;
6770
import org.junit.experimental.categories.Category;
@@ -1150,4 +1153,39 @@ public void testIsAsyncSession() throws SQLException, InterruptedException {
11501153
assertFalse(snowflakeConnection.getSfSession().isAsyncSession());
11511154
}
11521155
}
1156+
1157+
// Test for regenerating okta one-time token for versions > 3.15.1
1158+
@Test
1159+
@Ignore
1160+
public void testDataSourceOktaGenerates429StatusCode() throws Exception {
1161+
// test with username/password authentication
1162+
// set up DataSource object and ensure connection works
1163+
Map<String, String> params = getConnectionParameters();
1164+
SnowflakeBasicDataSource ds = new SnowflakeBasicDataSource();
1165+
ds.setServerName(params.get("host"));
1166+
ds.setSsl("on".equals(params.get("ssl")));
1167+
ds.setAccount(params.get("account"));
1168+
ds.setPortNumber(Integer.parseInt(params.get("port")));
1169+
ds.setUser(params.get("ssoUser"));
1170+
ds.setPassword(params.get("ssoPassword"));
1171+
ds.setAuthenticator("<okta address>");
1172+
Runnable r =
1173+
() -> {
1174+
try {
1175+
ds.getConnection();
1176+
} catch (SQLException e) {
1177+
throw new RuntimeException(e);
1178+
}
1179+
};
1180+
List<Thread> threadList = new ArrayList<>();
1181+
for (int i = 0;
1182+
i < 30;
1183+
++i) { // https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use#http-429-errors
1184+
threadList.add(new Thread(r));
1185+
}
1186+
threadList.forEach(Thread::start);
1187+
for (Thread thread : threadList) {
1188+
thread.join();
1189+
}
1190+
}
11531191
}

‎src/test/java/net/snowflake/client/jdbc/RestRequestTest.java

+14
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,20 @@ public void testExceptionAuthBasedTimeout() throws IOException {
359359
}
360360
}
361361

362+
@Test
363+
public void testExceptionAuthBasedTimeoutFor429ErrorCode() throws IOException {
364+
CloseableHttpClient client = mock(CloseableHttpClient.class);
365+
when(client.execute(any(HttpUriRequest.class)))
366+
.thenAnswer((Answer<CloseableHttpResponse>) invocation -> retryLoginResponse());
367+
368+
try {
369+
execute(client, "login-request.com/?requestId=abcd-1234", 2, 1, 30000, true, false);
370+
} catch (SnowflakeSQLException ex) {
371+
assertThat(
372+
ex.getErrorCode(), equalTo(ErrorCode.AUTHENTICATOR_REQUEST_TIMEOUT.getMessageCode()));
373+
}
374+
}
375+
362376
@Test
363377
public void testNoRetry() throws IOException, SnowflakeSQLException {
364378
boolean telemetryEnabled = TelemetryService.getInstance().isEnabled();

0 commit comments

Comments
 (0)
Please sign in to comment.