@@ -651,16 +651,31 @@ private static SFLoginOutput newSession(
651
651
loginInput .getHttpClientSettingsKey ());
652
652
} catch (SnowflakeSQLException ex ) {
653
653
if (ex .getErrorCode () == ErrorCode .AUTHENTICATOR_REQUEST_TIMEOUT .getMessageCode ()) {
654
- if (authenticatorType == ClientAuthnDTO .AuthenticatorType .SNOWFLAKE_JWT ) {
655
- SessionUtilKeyPair s =
656
- new SessionUtilKeyPair (
657
- loginInput .getPrivateKey (),
658
- loginInput .getPrivateKeyFile (),
659
- loginInput .getPrivateKeyFilePwd (),
660
- loginInput .getAccountName (),
661
- loginInput .getUserName ());
662
-
663
- data .put (ClientAuthnParameter .TOKEN .name (), s .issueJwtToken ());
654
+ if (authenticatorType == ClientAuthnDTO .AuthenticatorType .SNOWFLAKE_JWT
655
+ || authenticatorType == ClientAuthnDTO .AuthenticatorType .OKTA ) {
656
+
657
+ if (authenticatorType == ClientAuthnDTO .AuthenticatorType .SNOWFLAKE_JWT ) {
658
+ SessionUtilKeyPair s =
659
+ new SessionUtilKeyPair (
660
+ loginInput .getPrivateKey (),
661
+ loginInput .getPrivateKeyFile (),
662
+ loginInput .getPrivateKeyFilePwd (),
663
+ loginInput .getAccountName (),
664
+ loginInput .getUserName ());
665
+
666
+ data .put (ClientAuthnParameter .TOKEN .name (), s .issueJwtToken ());
667
+ } else if (authenticatorType == ClientAuthnDTO .AuthenticatorType .OKTA ) {
668
+ logger .debug ("Retrieve new token for Okta authentication." );
669
+ // If we need to retry, we need to get a new Okta token
670
+ tokenOrSamlResponse = getSamlResponseUsingOkta (loginInput );
671
+ data .put (ClientAuthnParameter .RAW_SAML_RESPONSE .name (), tokenOrSamlResponse );
672
+ authnData .setData (data );
673
+ String updatedJson = mapper .writeValueAsString (authnData );
674
+
675
+ StringEntity updatedInput = new StringEntity (updatedJson , StandardCharsets .UTF_8 );
676
+ updatedInput .setContentType ("application/json" );
677
+ postRequest .setEntity (updatedInput );
678
+ }
664
679
665
680
long elapsedSeconds = ex .getElapsedSeconds ();
666
681
@@ -690,7 +705,8 @@ private static SFLoginOutput newSession(
690
705
}
691
706
}
692
707
693
- // JWT renew should not count as a retry, so we pass back the current retry count.
708
+ // JWT or Okta renew should not count as a retry, so we pass back the current retry
709
+ // count.
694
710
retryCount = ex .getRetryCount ();
695
711
696
712
continue ;
@@ -1362,12 +1378,24 @@ private static void handleFederatedFlowError(SFLoginInput loginInput, Exception
1362
1378
*/
1363
1379
private static String getSamlResponseUsingOkta (SFLoginInput loginInput )
1364
1380
throws SnowflakeSQLException {
1365
- JsonNode dataNode = federatedFlowStep1 (loginInput );
1366
- String tokenUrl = dataNode .path ("tokenUrl" ).asText ();
1367
- String ssoUrl = dataNode .path ("ssoUrl" ).asText ();
1368
- federatedFlowStep2 (loginInput , tokenUrl , ssoUrl );
1369
- final String oneTimeToken = federatedFlowStep3 (loginInput , tokenUrl );
1370
- return federatedFlowStep4 (loginInput , ssoUrl , oneTimeToken );
1381
+ while (true ) {
1382
+ try {
1383
+ JsonNode dataNode = federatedFlowStep1 (loginInput );
1384
+ String tokenUrl = dataNode .path ("tokenUrl" ).asText ();
1385
+ String ssoUrl = dataNode .path ("ssoUrl" ).asText ();
1386
+ federatedFlowStep2 (loginInput , tokenUrl , ssoUrl );
1387
+ final String oneTimeToken = federatedFlowStep3 (loginInput , tokenUrl );
1388
+ return federatedFlowStep4 (loginInput , ssoUrl , oneTimeToken );
1389
+ } catch (SnowflakeSQLException ex ) {
1390
+ // This error gets thrown if the okta request encountered a retry-able error that
1391
+ // requires getting a new one-time token.
1392
+ if (ex .getErrorCode () == ErrorCode .AUTHENTICATOR_REQUEST_TIMEOUT .getMessageCode ()) {
1393
+ logger .debug ("Failed to get Okta SAML response. Retrying without changing retry count." );
1394
+ } else {
1395
+ throw ex ;
1396
+ }
1397
+ }
1398
+ }
1371
1399
}
1372
1400
1373
1401
/**
0 commit comments