diff --git a/src/main/java/net/snowflake/client/config/SFClientConfigParser.java b/src/main/java/net/snowflake/client/config/SFClientConfigParser.java index 9d93aed0a..7b05150d9 100644 --- a/src/main/java/net/snowflake/client/config/SFClientConfigParser.java +++ b/src/main/java/net/snowflake/client/config/SFClientConfigParser.java @@ -58,11 +58,13 @@ public static SFClientConfig loadSFClientConfig(String configFilePath) throws IO derivedConfigFilePath = driverLocation; } else { // 4. Read SF_CLIENT_CONFIG_FILE_NAME if it is present in user home directory. - String userHomeFilePath = - Paths.get(systemGetProperty("user.home"), SF_CLIENT_CONFIG_FILE_NAME).toString(); - if (Files.exists(Paths.get(userHomeFilePath))) { - logger.info("Using config file specified from home directory: {}", userHomeFilePath); - derivedConfigFilePath = userHomeFilePath; + String homeDirectory = systemGetProperty("user.home"); + if (homeDirectory != null) { + String userHomeFilePath = Paths.get(homeDirectory, SF_CLIENT_CONFIG_FILE_NAME).toString(); + if (Files.exists(Paths.get(userHomeFilePath))) { + logger.info("Using config file specified from home directory: {}", userHomeFilePath); + derivedConfigFilePath = userHomeFilePath; + } } } } diff --git a/src/main/java/net/snowflake/client/core/FileCacheManager.java b/src/main/java/net/snowflake/client/core/FileCacheManager.java index 2d00d6d0a..c8cc36089 100644 --- a/src/main/java/net/snowflake/client/core/FileCacheManager.java +++ b/src/main/java/net/snowflake/client/core/FileCacheManager.java @@ -4,6 +4,7 @@ package net.snowflake.client.core; +import static net.snowflake.client.core.FileUtil.isWritable; import static net.snowflake.client.jdbc.SnowflakeUtil.isWindows; import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetEnv; import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetProperty; @@ -27,6 +28,7 @@ import java.nio.file.attribute.PosixFilePermission; import java.nio.file.attribute.PosixFilePermissions; import java.util.Date; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.Stream; import net.snowflake.client.log.SFLogger; @@ -43,7 +45,6 @@ class FileCacheManager { private String cacheDirectorySystemProperty; private String cacheDirectoryEnvironmentVariable; private String baseCacheFileName; - private long cacheExpirationInMilliseconds; private long cacheFileLockExpirationInMilliseconds; private File cacheFile; @@ -74,12 +75,6 @@ FileCacheManager setBaseCacheFileName(String baseCacheFileName) { return this; } - FileCacheManager setCacheExpirationInSeconds(long cacheExpirationInSeconds) { - // converting from seconds to milliseconds - this.cacheExpirationInMilliseconds = cacheExpirationInSeconds * 1000; - return this; - } - FileCacheManager setCacheFileLockExpirationInSeconds(long cacheFileLockExpirationInSeconds) { this.cacheFileLockExpirationInMilliseconds = cacheFileLockExpirationInSeconds * 1000; return this; @@ -90,17 +85,22 @@ FileCacheManager setOnlyOwnerPermissions(boolean onlyOwnerPermissions) { return this; } + synchronized String getCacheFilePath() { + return cacheFile.getAbsolutePath(); + } + /** * Override the cache file. * * @param newCacheFile a file object to override the default one. */ - void overrideCacheFile(File newCacheFile) { + synchronized void overrideCacheFile(File newCacheFile) { if (!newCacheFile.exists()) { logger.debug("Cache file doesn't exists. File: {}", newCacheFile); } if (onlyOwnerPermissions) { - FileUtil.throwWhenPermiossionDifferentThanReadWriteForOwner( + FileUtil.throwWhenFilePermissionsWiderThanUserOnly(newCacheFile, "Override cache file"); + FileUtil.throwWhenParentDirectoryPermissionsWiderThanUserOnly( newCacheFile, "Override cache file"); } else { FileUtil.logFileUsage(cacheFile, "Override cache file", false); @@ -110,7 +110,7 @@ void overrideCacheFile(File newCacheFile) { this.baseCacheFileName = newCacheFile.getName(); } - FileCacheManager build() { + synchronized FileCacheManager build() { // try to get cacheDir from system property or environment variable String cacheDirPath = this.cacheDirectorySystemProperty != null @@ -134,32 +134,30 @@ FileCacheManager build() { if (cacheDirPath != null) { this.cacheDir = new File(cacheDirPath); } else { - // use user home directory to store the cache file - String homeDir = systemGetProperty("user.home"); - if (homeDir != null) { - // Checking if home directory is writable. - File homeFile = new File(homeDir); - if (!homeFile.canWrite()) { - logger.debug("Home directory not writeable, skip using cache", false); - homeDir = null; - } - } - if (homeDir == null) { - // if still home directory is null, no cache dir is set. + this.cacheDir = getDefaultCacheDir(); + } + if (cacheDir == null) { + return this; + } + if (!cacheDir.exists()) { + try { + Files.createDirectories( + cacheDir.toPath(), + PosixFilePermissions.asFileAttribute( + Stream.of( + PosixFilePermission.OWNER_READ, + PosixFilePermission.OWNER_WRITE, + PosixFilePermission.OWNER_EXECUTE) + .collect(Collectors.toSet()))); + } catch (IOException e) { + logger.info( + "Failed to create the cache directory: {}. Ignored. {}", + e.getMessage(), + cacheDir.getAbsoluteFile()); return this; } - if (Constants.getOS() == Constants.OS.WINDOWS) { - this.cacheDir = - new File( - new File(new File(new File(homeDir, "AppData"), "Local"), "Snowflake"), "Caches"); - } else if (Constants.getOS() == Constants.OS.MAC) { - this.cacheDir = new File(new File(new File(homeDir, "Library"), "Caches"), "Snowflake"); - } else { - this.cacheDir = new File(new File(homeDir, ".cache"), "snowflake"); - } } - - if (!this.cacheDir.mkdirs() && !this.cacheDir.exists()) { + if (!this.cacheDir.exists()) { logger.debug( "Cannot create the cache directory {}. Giving up.", this.cacheDir.getAbsolutePath()); return this; @@ -199,12 +197,72 @@ FileCacheManager build() { return this; } - /** Reads the cache file. */ - JsonNode readCacheFile() { - if (cacheFile == null || !this.checkCacheLockFile()) { - // no cache or the cache is not valid. + static File getDefaultCacheDir() { + if (Constants.getOS() == Constants.OS.LINUX) { + String xdgCacheHome = getXdgCacheHome(); + if (xdgCacheHome != null) { + return new File(xdgCacheHome, "snowflake"); + } + } + + String homeDir = getHomeDirProperty(); + if (homeDir == null) { + // if still home directory is null, no cache dir is set. + return null; + } + if (Constants.getOS() == Constants.OS.WINDOWS) { + return new File( + new File(new File(new File(homeDir, "AppData"), "Local"), "Snowflake"), "Caches"); + } else if (Constants.getOS() == Constants.OS.MAC) { + return new File(new File(new File(homeDir, "Library"), "Caches"), "Snowflake"); + } else { + return new File(new File(homeDir, ".cache"), "snowflake"); + } + } + + private static String getXdgCacheHome() { + String xdgCacheHome = systemGetEnv("XDG_CACHE_HOME"); + if (xdgCacheHome != null && isWritable(xdgCacheHome)) { + return xdgCacheHome; + } + return null; + } + + private static String getHomeDirProperty() { + String homeDir = systemGetProperty("user.home"); + if (homeDir != null && isWritable(homeDir)) { + return homeDir; + } + return null; + } + + synchronized T withLock(Supplier supplier) { + if (cacheFile == null) { + logger.error("No cache file assigned", false); + return null; + } + if (cacheLockFile == null) { + logger.error("No cache lock file assigned", false); return null; + } else if (cacheLockFile.exists()) { + deleteCacheLockIfExpired(); } + + if (!tryToLockCacheFile()) { + logger.debug("Failed to lock the file. Skipping cache operation", false); + return null; + } + try { + return supplier.get(); + } finally { + if (!unlockCacheFile()) { + logger.debug("Failed to unlock cache file", false); + } + } + } + + /** Reads the cache file. */ + synchronized JsonNode readCacheFile() { try { if (!cacheFile.exists()) { logger.debug("Cache file doesn't exists. File: {}", cacheFile); @@ -215,7 +273,8 @@ JsonNode readCacheFile() { new InputStreamReader(new FileInputStream(cacheFile), DEFAULT_FILE_ENCODING)) { if (onlyOwnerPermissions) { - FileUtil.throwWhenPermiossionDifferentThanReadWriteForOwner(cacheFile, "Read cache"); + FileUtil.throwWhenFilePermissionsWiderThanUserOnly(cacheFile, "Read cache"); + FileUtil.throwWhenParentDirectoryPermissionsWiderThanUserOnly(cacheFile, "Read cache"); FileUtil.throwWhenOwnerDifferentThanCurrentUser(cacheFile, "Read cache"); } else { FileUtil.logFileUsage(cacheFile, "Read cache", false); @@ -228,15 +287,8 @@ JsonNode readCacheFile() { return null; } - void writeCacheFile(JsonNode input) { + synchronized void writeCacheFile(JsonNode input) { logger.debug("Writing cache file. File: {}", cacheFile); - if (cacheFile == null || !tryLockCacheFile()) { - // no cache file or it failed to lock file - logger.debug( - "No cache file exists or failed to lock the file. Skipping writing the cache", false); - return; - } - // NOTE: must unlock cache file try { if (input == null) { return; @@ -244,7 +296,9 @@ void writeCacheFile(JsonNode input) { try (Writer writer = new OutputStreamWriter(new FileOutputStream(cacheFile), DEFAULT_FILE_ENCODING)) { if (onlyOwnerPermissions) { - FileUtil.throwWhenPermiossionDifferentThanReadWriteForOwner(cacheFile, "Write to cache"); + FileUtil.throwWhenFilePermissionsWiderThanUserOnly(cacheFile, "Write to cache"); + FileUtil.throwWhenParentDirectoryPermissionsWiderThanUserOnly( + cacheFile, "Write to cache"); } else { FileUtil.logFileUsage(cacheFile, "Write to cache", false); } @@ -252,14 +306,10 @@ void writeCacheFile(JsonNode input) { } } catch (IOException ex) { logger.debug("Failed to write the cache file. File: {}", cacheFile); - } finally { - if (!unlockCacheFile()) { - logger.debug("Failed to unlock cache file", false); - } } } - void deleteCacheFile() { + synchronized void deleteCacheFile() { logger.debug("Deleting cache file. File: {}, lock file: {}", cacheFile, cacheLockFile); if (cacheFile == null) { @@ -277,68 +327,44 @@ void deleteCacheFile() { * * @return true if success or false */ - private boolean tryLockCacheFile() { + private synchronized boolean tryToLockCacheFile() { int cnt = 0; boolean locked = false; - while (cnt < 100 && !(locked = lockCacheFile())) { + while (cnt < 5 && !(locked = lockCacheFile())) { try { - Thread.sleep(100); + Thread.sleep(10); } catch (InterruptedException ex) { // doesn't matter } ++cnt; } if (!locked) { - logger.debug("Failed to lock the cache file.", false); + deleteCacheLockIfExpired(); + if (!lockCacheFile()) { + logger.debug("Failed to lock the cache file.", false); + } } return locked; } - /** - * Lock cache file by creating a lock directory - * - * @return true if success or false - */ - private boolean lockCacheFile() { - return cacheLockFile.mkdirs(); - } - - /** - * Unlock cache file by deleting a lock directory - * - * @return true if success or false - */ - private boolean unlockCacheFile() { - return cacheLockFile.delete(); - } - - private boolean checkCacheLockFile() { + private synchronized void deleteCacheLockIfExpired() { long currentTime = new Date().getTime(); - long cacheFileTs = fileCreationTime(cacheFile); - - if (!cacheLockFile.exists() - && cacheFileTs > 0 - && currentTime - this.cacheExpirationInMilliseconds <= cacheFileTs) { - logger.debug("No cache file lock directory exists and cache file is up to date.", false); - return true; - } - long lockFileTs = fileCreationTime(cacheLockFile); if (lockFileTs < 0) { - // failed to get the timestamp of lock directory - return false; - } - if (lockFileTs < currentTime - this.cacheFileLockExpirationInMilliseconds) { + logger.debug("Failed to get the timestamp of lock directory"); + } else if (lockFileTs < currentTime - this.cacheFileLockExpirationInMilliseconds) { // old lock file - if (!cacheLockFile.delete()) { - logger.debug("Failed to delete the directory. Dir: {}", cacheLockFile); - return false; + try { + if (!cacheLockFile.delete()) { + logger.debug("Failed to delete the directory. Dir: {}", cacheLockFile); + } else { + logger.debug("Deleted expired cache lock directory.", false); + } + } catch (Exception e) { + logger.debug( + "Failed to delete the directory. Dir: {}, Error: {}", cacheLockFile, e.getMessage()); } - logger.debug("Deleted the cache lock directory, because it was old.", false); - return currentTime - this.cacheExpirationInMilliseconds <= cacheFileTs; } - logger.debug("Failed to lock the file. Ignored.", false); - return false; } /** @@ -346,7 +372,7 @@ private boolean checkCacheLockFile() { * * @return epoch time in ms */ - private static long fileCreationTime(File targetFile) { + private static synchronized long fileCreationTime(File targetFile) { if (!targetFile.exists()) { logger.debug("File not exists. File: {}", targetFile); return -1; @@ -361,7 +387,21 @@ private static long fileCreationTime(File targetFile) { return -1; } - String getCacheFilePath() { - return cacheFile.getAbsolutePath(); + /** + * Lock cache file by creating a lock directory + * + * @return true if success or false + */ + private synchronized boolean lockCacheFile() { + return cacheLockFile.mkdirs(); + } + + /** + * Unlock cache file by deleting a lock directory + * + * @return true if success or false + */ + private synchronized boolean unlockCacheFile() { + return cacheLockFile.delete(); } } diff --git a/src/main/java/net/snowflake/client/core/FileUtil.java b/src/main/java/net/snowflake/client/core/FileUtil.java index 005b9ac18..a383b69ce 100644 --- a/src/main/java/net/snowflake/client/core/FileUtil.java +++ b/src/main/java/net/snowflake/client/core/FileUtil.java @@ -41,12 +41,30 @@ public static void logFileUsage(String stringPath, String context, boolean logRe logFileUsage(path, context, logReadAccess); } - public static void throwWhenPermiossionDifferentThanReadWriteForOwner(File file, String context) { - throwWhenPermiossionDifferentThanReadWriteForOwner(file.toPath(), context); + public static boolean isWritable(String path) { + File file = new File(path); + if (!file.canWrite()) { + logger.debug("File/directory not writeable: {}", path); + return false; + } + return true; + } + + public static void throwWhenParentDirectoryPermissionsWiderThanUserOnly( + File file, String context) { + throwWhenDirectoryPermissionsWiderThanUserOnly(file.getParentFile(), context); + } + + public static void throwWhenFilePermissionsWiderThanUserOnly(File file, String context) { + throwWhenPermissionsWiderThanUserOnly(file.toPath(), context, false); } - public static void throwWhenPermiossionDifferentThanReadWriteForOwner( - Path filePath, String context) { + public static void throwWhenDirectoryPermissionsWiderThanUserOnly(File file, String context) { + throwWhenPermissionsWiderThanUserOnly(file.toPath(), context, true); + } + + public static void throwWhenPermissionsWiderThanUserOnly( + Path filePath, String context, boolean isDirectory) { // we do not check the permissions for Windows if (isWindows()) { return; @@ -58,16 +76,28 @@ public static void throwWhenPermiossionDifferentThanReadWriteForOwner( boolean isReadableByOthers = isPermPresent(filePermissions, READ_BY_OTHERS); boolean isExecutable = isPermPresent(filePermissions, EXECUTABLE); - if (isWritableByOthers || isReadableByOthers || isExecutable) { + boolean permissionsTooOpen; + if (isDirectory) { + permissionsTooOpen = isWritableByOthers || isReadableByOthers; + } else { + permissionsTooOpen = isWritableByOthers || isReadableByOthers || isExecutable; + } + if (permissionsTooOpen) { logger.debug( - "{}File {} access rights: {}", getContextStr(context), filePath, filePermissions); + "{}File/directory {} access rights: {}", + getContextStr(context), + filePath, + filePermissions); throw new SecurityException( - String.format("Access to file %s is wider than allowed only to the owner", filePath)); + String.format( + "Access to file or directory %s is wider than allowed only to the owner. Remove cached file/directory and re-run the driver.", + filePath)); } } catch (IOException e) { throw new SecurityException( String.format( - "%s Unable to access the file to check the permissions. Error: %s", filePath, e)); + "%s Unable to access the file/directory to check the permissions. Error: %s", + filePath, e)); } } diff --git a/src/main/java/net/snowflake/client/core/HexUtil.java b/src/main/java/net/snowflake/client/core/HexUtil.java new file mode 100644 index 000000000..1200af0e6 --- /dev/null +++ b/src/main/java/net/snowflake/client/core/HexUtil.java @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2025 Snowflake Computing Inc. All rights reserved. + */ + +package net.snowflake.client.core; + +class HexUtil { + + /** + * Converts Byte array to hex string + * + * @param bytes a byte array + * @return a string in hexadecimal code + */ + static String byteToHexString(byte[] bytes) { + final char[] hexArray = "0123456789ABCDEF".toCharArray(); + char[] hexChars = new char[bytes.length * 2]; + for (int j = 0; j < bytes.length; j++) { + int v = bytes[j] & 0xFF; + hexChars[j * 2] = hexArray[v >>> 4]; + hexChars[j * 2 + 1] = hexArray[v & 0x0F]; + } + return new String(hexChars); + } +} diff --git a/src/main/java/net/snowflake/client/core/SFTrustManager.java b/src/main/java/net/snowflake/client/core/SFTrustManager.java index 6440ec6c3..ae7e82f2d 100644 --- a/src/main/java/net/snowflake/client/core/SFTrustManager.java +++ b/src/main/java/net/snowflake/client/core/SFTrustManager.java @@ -225,7 +225,6 @@ public class SFTrustManager extends X509ExtendedTrustManager { .setCacheDirectorySystemProperty(CACHE_DIR_PROP) .setCacheDirectoryEnvironmentVariable(CACHE_DIR_ENV) .setBaseCacheFileName(CACHE_FILE_NAME) - .setCacheExpirationInSeconds(CACHE_EXPIRATION_IN_SECONDS) .setCacheFileLockExpirationInSeconds(CACHE_FILE_LOCK_EXPIRATION_IN_SECONDS) .setOnlyOwnerPermissions(false) .build(); @@ -407,8 +406,8 @@ private static String encodeCacheKey(OcspResponseCacheKey ocsp_cache_key) { private static String CertificateIDToString(CertificateID certificateID) { return String.format( "CertID. NameHash: %s, KeyHash: %s, Serial Number: %s", - byteToHexString(certificateID.getIssuerNameHash()), - byteToHexString(certificateID.getIssuerKeyHash()), + HexUtil.byteToHexString(certificateID.getIssuerNameHash()), + HexUtil.byteToHexString(certificateID.getIssuerKeyHash()), MessageFormat.format("{0,number,#}", certificateID.getSerialNumber())); } @@ -538,23 +537,6 @@ private static void verifySignature( } } - /** - * Converts Byte array to hex string - * - * @param bytes a byte array - * @return a string in hexadecimal code - */ - private static String byteToHexString(byte[] bytes) { - final char[] hexArray = "0123456789ABCDEF".toCharArray(); - char[] hexChars = new char[bytes.length * 2]; - for (int j = 0; j < bytes.length; j++) { - int v = bytes[j] & 0xFF; - hexChars[j * 2] = hexArray[v >>> 4]; - hexChars[j * 2 + 1] = hexArray[v & 0x0F]; - } - return new String(hexChars); - } - /** * Gets HttpClient object * @@ -1610,7 +1592,9 @@ public boolean equals(Object obj) { public String toString() { return String.format( "OcspResponseCacheKey: NameHash: %s, KeyHash: %s, SerialNumber: %s", - byteToHexString(nameHash), byteToHexString(keyHash), serialNumber.toString()); + HexUtil.byteToHexString(nameHash), + HexUtil.byteToHexString(keyHash), + serialNumber.toString()); } } diff --git a/src/main/java/net/snowflake/client/core/SecureStorageAppleManager.java b/src/main/java/net/snowflake/client/core/SecureStorageAppleManager.java index fb467edcc..bfe7540bf 100644 --- a/src/main/java/net/snowflake/client/core/SecureStorageAppleManager.java +++ b/src/main/java/net/snowflake/client/core/SecureStorageAppleManager.java @@ -32,7 +32,7 @@ public SecureStorageStatus setCredential(String host, String user, String type, return SecureStorageStatus.SUCCESS; } - String target = SecureStorageManager.convertTarget(host, user, type); + String target = SecureStorageManager.buildCredentialsKey(host, user, type); byte[] targetBytes = target.getBytes(StandardCharsets.UTF_8); byte[] userBytes = user.toUpperCase().getBytes(StandardCharsets.UTF_8); byte[] credBytes = cred.getBytes(StandardCharsets.UTF_8); @@ -92,7 +92,7 @@ public SecureStorageStatus setCredential(String host, String user, String type, } public String getCredential(String host, String user, String type) { - String target = SecureStorageManager.convertTarget(host, user, type); + String target = SecureStorageManager.buildCredentialsKey(host, user, type); byte[] targetBytes = target.getBytes(StandardCharsets.UTF_8); byte[] userBytes = user.toUpperCase().getBytes(StandardCharsets.UTF_8); @@ -141,7 +141,7 @@ public String getCredential(String host, String user, String type) { } public SecureStorageStatus deleteCredential(String host, String user, String type) { - String target = SecureStorageManager.convertTarget(host, user, type); + String target = SecureStorageManager.buildCredentialsKey(host, user, type); byte[] targetBytes = target.getBytes(StandardCharsets.UTF_8); byte[] userBytes = user.toUpperCase().getBytes(StandardCharsets.UTF_8); diff --git a/src/main/java/net/snowflake/client/core/SecureStorageLinuxManager.java b/src/main/java/net/snowflake/client/core/SecureStorageLinuxManager.java index d8c44dda3..00a67f372 100644 --- a/src/main/java/net/snowflake/client/core/SecureStorageLinuxManager.java +++ b/src/main/java/net/snowflake/client/core/SecureStorageLinuxManager.java @@ -23,14 +23,12 @@ */ public class SecureStorageLinuxManager implements SecureStorageManager { private static final SFLogger logger = SFLoggerFactory.getLogger(SecureStorageLinuxManager.class); - private static final String CACHE_FILE_NAME = "temporary_credential.json"; + private static final String CACHE_FILE_NAME = "credential_cache_v1.json"; private static final String CACHE_DIR_PROP = "net.snowflake.jdbc.temporaryCredentialCacheDir"; private static final String CACHE_DIR_ENV = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"; - private static final long CACHE_EXPIRATION_IN_SECONDS = 86400L; + private static final String CACHE_FILE_TOKENS_OBJECT_NAME = "tokens"; private static final long CACHE_FILE_LOCK_EXPIRATION_IN_SECONDS = 60L; - private FileCacheManager fileCacheManager; - - private final Map> localCredCache = new HashMap<>(); + private final FileCacheManager fileCacheManager; private SecureStorageLinuxManager() { fileCacheManager = @@ -38,7 +36,6 @@ private SecureStorageLinuxManager() { .setCacheDirectorySystemProperty(CACHE_DIR_PROP) .setCacheDirectoryEnvironmentVariable(CACHE_DIR_ENV) .setBaseCacheFileName(CACHE_FILE_NAME) - .setCacheExpirationInSeconds(CACHE_EXPIRATION_IN_SECONDS) .setCacheFileLockExpirationInSeconds(CACHE_FILE_LOCK_EXPIRATION_IN_SECONDS) .build(); logger.debug( @@ -53,78 +50,89 @@ public static SecureStorageLinuxManager getInstance() { return SecureStorageLinuxManagerHolder.INSTANCE; } - private ObjectNode localCacheToJson() { - ObjectNode res = mapper.createObjectNode(); - for (Map.Entry> elem : localCredCache.entrySet()) { - String elemHost = elem.getKey(); - Map hostMap = elem.getValue(); - ObjectNode hostNode = mapper.createObjectNode(); - for (Map.Entry elem0 : hostMap.entrySet()) { - hostNode.put(elem0.getKey(), elem0.getValue()); - } - res.set(elemHost, hostNode); - } - return res; - } - + @Override public synchronized SecureStorageStatus setCredential( String host, String user, String type, String token) { if (Strings.isNullOrEmpty(token)) { logger.warn("No token provided", false); return SecureStorageStatus.SUCCESS; } - - localCredCache.computeIfAbsent(host.toUpperCase(), newMap -> new HashMap<>()); - - Map hostMap = localCredCache.get(host.toUpperCase()); - hostMap.put(SecureStorageManager.convertTarget(host, user, type), token); - - fileCacheManager.writeCacheFile(localCacheToJson()); + fileCacheManager.withLock( + () -> { + Map> cachedCredentials = + readJsonStoreCache(fileCacheManager.readCacheFile()); + cachedCredentials.computeIfAbsent( + CACHE_FILE_TOKENS_OBJECT_NAME, tokensMap -> new HashMap<>()); + Map credentialsMap = cachedCredentials.get(CACHE_FILE_TOKENS_OBJECT_NAME); + credentialsMap.put(SecureStorageManager.buildCredentialsKey(host, user, type), token); + fileCacheManager.writeCacheFile( + SecureStorageLinuxManager.this.localCacheToJson(cachedCredentials)); + return null; + }); return SecureStorageStatus.SUCCESS; } + @Override public synchronized String getCredential(String host, String user, String type) { - JsonNode res = fileCacheManager.readCacheFile(); - readJsonStoreCache(res); - - Map hostMap = localCredCache.get(host.toUpperCase()); - - if (hostMap == null) { - return null; - } - - return hostMap.get(SecureStorageManager.convertTarget(host, user, type)); + return fileCacheManager.withLock( + () -> { + JsonNode res = fileCacheManager.readCacheFile(); + Map> cache = readJsonStoreCache(res); + Map credentialsMap = cache.get(CACHE_FILE_TOKENS_OBJECT_NAME); + if (credentialsMap == null) { + return null; + } + return credentialsMap.get(SecureStorageManager.buildCredentialsKey(host, user, type)); + }); } - /** May delete credentials which doesn't belong to this process */ + @Override public synchronized SecureStorageStatus deleteCredential(String host, String user, String type) { - Map hostMap = localCredCache.get(host.toUpperCase()); - if (hostMap != null) { - hostMap.remove(SecureStorageManager.convertTarget(host, user, type)); - if (hostMap.isEmpty()) { - localCredCache.remove(host.toUpperCase()); + fileCacheManager.withLock( + () -> { + JsonNode res = fileCacheManager.readCacheFile(); + Map> cache = readJsonStoreCache(res); + Map credentialsMap = cache.get(CACHE_FILE_TOKENS_OBJECT_NAME); + if (credentialsMap != null) { + credentialsMap.remove(SecureStorageManager.buildCredentialsKey(host, user, type)); + if (credentialsMap.isEmpty()) { + cache.remove(CACHE_FILE_TOKENS_OBJECT_NAME); + } + } + fileCacheManager.writeCacheFile(localCacheToJson(cache)); + return null; + }); + return SecureStorageStatus.SUCCESS; + } + + private ObjectNode localCacheToJson(Map> cache) { + ObjectNode jsonNode = mapper.createObjectNode(); + Map tokensMap = cache.get(CACHE_FILE_TOKENS_OBJECT_NAME); + if (tokensMap != null) { + ObjectNode tokensNode = mapper.createObjectNode(); + for (Map.Entry credential : tokensMap.entrySet()) { + tokensNode.put(credential.getKey(), credential.getValue()); } + jsonNode.set(CACHE_FILE_TOKENS_OBJECT_NAME, tokensNode); } - fileCacheManager.writeCacheFile(localCacheToJson()); - return SecureStorageStatus.SUCCESS; + return jsonNode; } - private void readJsonStoreCache(JsonNode m) { - if (m == null || !m.getNodeType().equals(JsonNodeType.OBJECT)) { + private Map> readJsonStoreCache(JsonNode node) { + Map> cache = new HashMap<>(); + if (node == null || !node.getNodeType().equals(JsonNodeType.OBJECT)) { logger.debug("Invalid cache file format."); - return; + return cache; } - for (Iterator> itr = m.fields(); itr.hasNext(); ) { - Map.Entry hostMap = itr.next(); - String host = hostMap.getKey(); - if (!localCredCache.containsKey(host)) { - localCredCache.put(host, new HashMap<>()); - } - JsonNode userJsonNode = hostMap.getValue(); - for (Iterator> itr0 = userJsonNode.fields(); itr0.hasNext(); ) { - Map.Entry userMap = itr0.next(); - localCredCache.get(host).put(userMap.getKey(), userMap.getValue().asText()); + cache.put(CACHE_FILE_TOKENS_OBJECT_NAME, new HashMap<>()); + JsonNode credentialsNode = node.get(CACHE_FILE_TOKENS_OBJECT_NAME); + Map credentialsCache = cache.get(CACHE_FILE_TOKENS_OBJECT_NAME); + if (credentialsNode != null && node.getNodeType().equals(JsonNodeType.OBJECT)) { + for (Iterator> itr = credentialsNode.fields(); itr.hasNext(); ) { + Map.Entry credential = itr.next(); + credentialsCache.put(credential.getKey(), credential.getValue().asText()); } } + return cache; } } diff --git a/src/main/java/net/snowflake/client/core/SecureStorageManager.java b/src/main/java/net/snowflake/client/core/SecureStorageManager.java index d64c26c38..c7b45f30c 100644 --- a/src/main/java/net/snowflake/client/core/SecureStorageManager.java +++ b/src/main/java/net/snowflake/client/core/SecureStorageManager.java @@ -4,12 +4,14 @@ package net.snowflake.client.core; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; + /** * Interface for accessing Platform specific Local Secure Storage E.g. keychain on Mac credential * manager on Windows */ interface SecureStorageManager { - String DRIVER_NAME = "SNOWFLAKE-JDBC-DRIVER"; int COLON_CHAR_LENGTH = 1; SecureStorageStatus setCredential(String host, String user, String type, String token); @@ -18,24 +20,23 @@ interface SecureStorageManager { SecureStorageStatus deleteCredential(String host, String user, String type); - static String convertTarget(String host, String user, String type) { + static String buildCredentialsKey(String host, String user, String type) { StringBuilder target = - new StringBuilder( - host.length() - + user.length() - + DRIVER_NAME.length() - + type.length() - + 3 * COLON_CHAR_LENGTH); + new StringBuilder(host.length() + user.length() + type.length() + 3 * COLON_CHAR_LENGTH); target.append(host.toUpperCase()); target.append(":"); target.append(user.toUpperCase()); target.append(":"); - target.append(DRIVER_NAME); - target.append(":"); target.append(type.toUpperCase()); - return target.toString(); + try { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] hash = md.digest(target.toString().getBytes()); + return HexUtil.byteToHexString(hash); + } catch (NoSuchAlgorithmException e) { + throw new RuntimeException(e); + } } enum SecureStorageStatus { diff --git a/src/main/java/net/snowflake/client/core/SecureStorageWindowsManager.java b/src/main/java/net/snowflake/client/core/SecureStorageWindowsManager.java index e47a6d88b..0cf6e5ff5 100644 --- a/src/main/java/net/snowflake/client/core/SecureStorageWindowsManager.java +++ b/src/main/java/net/snowflake/client/core/SecureStorageWindowsManager.java @@ -47,7 +47,7 @@ public SecureStorageStatus setCredential(String host, String user, String type, Memory credBlobMem = new Memory(credBlob.length); credBlobMem.write(0, credBlob, 0, credBlob.length); - String target = SecureStorageManager.convertTarget(host, user, type); + String target = SecureStorageManager.buildCredentialsKey(host, user, type); SecureStorageWindowsCredential cred = new SecureStorageWindowsCredential(); cred.Type = SecureStorageWindowsCredentialType.CRED_TYPE_GENERIC.getType(); @@ -76,7 +76,7 @@ public SecureStorageStatus setCredential(String host, String user, String type, public String getCredential(String host, String user, String type) { PointerByReference pCredential = new PointerByReference(); - String target = SecureStorageManager.convertTarget(host, user, type); + String target = SecureStorageManager.buildCredentialsKey(host, user, type); try { boolean ret = false; @@ -127,7 +127,7 @@ public String getCredential(String host, String user, String type) { } public SecureStorageStatus deleteCredential(String host, String user, String type) { - String target = SecureStorageManager.convertTarget(host, user, type); + String target = SecureStorageManager.buildCredentialsKey(host, user, type); boolean ret = false; synchronized (advapi32Lib) { diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java index 0872fcd74..c232717be 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactory.java @@ -119,6 +119,7 @@ private void assertContainsClientCredentials( AssertUtil.assertTrue( loginInput.getOauthLoginInput().getClientSecret() != null, String.format( - "passing oauthClientSecret is required for %s authentication", authenticatorType.name())); + "passing oauthClientSecret is required for %s authentication", + authenticatorType.name())); } } diff --git a/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java index fa9d3baf0..e0cd3eb1c 100644 --- a/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java +++ b/src/main/java/net/snowflake/client/core/auth/oauth/TokenResponseDTO.java @@ -24,7 +24,7 @@ public class TokenResponseDTO { @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) public TokenResponseDTO( - @JsonProperty("access_token") String accessToken, + @JsonProperty(value = "access_token", required = true) String accessToken, @JsonProperty("refresh_token") String refreshToken, @JsonProperty("token_type") String tokenType, @JsonProperty("scope") String scope, diff --git a/src/test/java/net/snowflake/client/core/FileCacheManagerTest.java b/src/test/java/net/snowflake/client/core/FileCacheManagerTest.java index 023c20d40..ab1fb676c 100644 --- a/src/test/java/net/snowflake/client/core/FileCacheManagerTest.java +++ b/src/test/java/net/snowflake/client/core/FileCacheManagerTest.java @@ -24,7 +24,9 @@ import net.snowflake.client.annotations.RunOnLinuxOrMac; import net.snowflake.client.category.TestTags; import net.snowflake.client.jdbc.BaseJDBCTest; +import net.snowflake.client.jdbc.SnowflakeUtil; import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Tag; @@ -38,10 +40,9 @@ @Tag(TestTags.CORE) class FileCacheManagerTest extends BaseJDBCTest { - private static final String CACHE_FILE_NAME = "temporary_credential.json"; + private static final String CACHE_FILE_NAME = "credential_cache_v1.json.json"; private static final String CACHE_DIR_PROP = "net.snowflake.jdbc.temporaryCredentialCacheDir"; private static final String CACHE_DIR_ENV = "SF_TEMPORARY_CREDENTIAL_CACHE_DIR"; - private static final long CACHE_EXPIRATION_IN_SECONDS = 86400L; private static final long CACHE_FILE_LOCK_EXPIRATION_IN_SECONDS = 60L; private FileCacheManager fileCacheManager; @@ -54,7 +55,6 @@ public void setup() throws IOException { .setCacheDirectorySystemProperty(CACHE_DIR_PROP) .setCacheDirectoryEnvironmentVariable(CACHE_DIR_ENV) .setBaseCacheFileName(CACHE_FILE_NAME) - .setCacheExpirationInSeconds(CACHE_EXPIRATION_IN_SECONDS) .setCacheFileLockExpirationInSeconds(CACHE_FILE_LOCK_EXPIRATION_IN_SECONDS) .build(); cacheFile = createCacheFile(); @@ -65,34 +65,41 @@ public void clean() throws IOException { if (Files.exists(cacheFile.toPath())) { Files.delete(cacheFile.toPath()); } + if (Files.exists(cacheFile.getParentFile().toPath())) { + Files.delete(cacheFile.getParentFile().toPath()); + } } @ParameterizedTest @CsvSource({ - "rwx------,false", - "rw-------,true", - "r-x------,false", - "r--------,true", - "rwxrwx---,false", - "rwxrw----,false", - "rwxr-x---,false", - "rwxr-----,false", - "rwx-wx---,false", - "rwx-w----,false", - "rwx--x---,false", - "rwx---rwx,false", - "rwx---rw-,false", - "rwx---r-x,false", - "rwx---r--,false", - "rwx----wx,false", - "rwx----w-,false", - "rwx-----x,false" + "rwx------,rwx------,false", + "rw-------,rwx------,true", + "rw-------,rwx--xrwx,false", + "r-x------,rwx------,false", + "r--------,rwx------,true", + "rwxrwx---,rwx------,false", + "rwxrw----,rwx------,false", + "rwxr-x---,rwx------,false", + "rwxr-----,rwx------,false", + "rwx-wx---,rwx------,false", + "rwx-w----,rwx------,false", + "rwx--x---,rwx------,false", + "rwx---rwx,rwx------,false", + "rwx---rw-,rwx------,false", + "rwx---r-x,rwx------,false", + "rwx---r--,rwx------,false", + "rwx----wx,rwx------,false", + "rwx----w-,rwx------,false", + "rwx-----x,rwx------,false" }) @RunOnLinuxOrMac public void throwWhenReadCacheFileWithPermissionDifferentThanReadWriteForUserTest( - String permission, boolean isSucceed) throws IOException { + String permission, String parentDirectoryPermissions, boolean isSucceed) throws IOException { fileCacheManager.overrideCacheFile(cacheFile); Files.setPosixFilePermissions(cacheFile.toPath(), PosixFilePermissions.fromString(permission)); + Files.setPosixFilePermissions( + cacheFile.getParentFile().toPath(), + PosixFilePermissions.fromString(parentDirectoryPermissions)); if (isSucceed) { assertDoesNotThrow(() -> fileCacheManager.readCacheFile()); } else { @@ -134,17 +141,126 @@ public void throwWhenOverrideCacheFileNotFound() { assertTrue( ex.getMessage() .contains( - "Unable to access the file to check the permissions. Error: java.nio.file.NoSuchFileException:")); + "Unable to access the file/directory to check the permissions. Error: java.nio.file.NoSuchFileException:")); + } + + @Test + public void shouldCreateCacheDirForLinuxXDG() { + try (MockedStatic constantsMockedStatic = Mockito.mockStatic(Constants.class)) { + constantsMockedStatic.when(Constants::getOS).thenReturn(Constants.OS.LINUX); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when(() -> SnowflakeUtil.systemGetEnv("XDG_CACHE_HOME")) + .thenReturn("/XDG/Cache/"); + try (MockedStatic fileUtilMockedStatic = Mockito.mockStatic(FileUtil.class)) { + fileUtilMockedStatic.when(() -> FileUtil.isWritable("/XDG/Cache/")).thenReturn(true); + File defaultCacheDir = FileCacheManager.getDefaultCacheDir(); + Assertions.assertNotNull(defaultCacheDir); + Assertions.assertEquals("/XDG/Cache/snowflake", defaultCacheDir.getAbsolutePath()); + } + } + } + } + + @Test + public void shouldCreateCacheDirForLinuxWithoutXDG() { + try (MockedStatic constantsMockedStatic = Mockito.mockStatic(Constants.class)) { + constantsMockedStatic.when(Constants::getOS).thenReturn(Constants.OS.LINUX); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when(() -> SnowflakeUtil.systemGetEnv("XDG_CACHE_HOME")) + .thenReturn(null); + snowflakeUtilMockedStatic + .when(() -> SnowflakeUtil.systemGetProperty("user.home")) + .thenReturn("/User/Home"); + try (MockedStatic fileUtilMockedStatic = Mockito.mockStatic(FileUtil.class)) { + fileUtilMockedStatic.when(() -> FileUtil.isWritable("/User/Home")).thenReturn(true); + File defaultCacheDir = FileCacheManager.getDefaultCacheDir(); + Assertions.assertNotNull(defaultCacheDir); + Assertions.assertEquals("/User/Home/.cache/snowflake", defaultCacheDir.getAbsolutePath()); + } + } + } + } + + @Test + public void shouldCreateCacheDirForWindows() { + try (MockedStatic constantsMockedStatic = Mockito.mockStatic(Constants.class)) { + constantsMockedStatic.when(Constants::getOS).thenReturn(Constants.OS.WINDOWS); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when(() -> SnowflakeUtil.systemGetProperty("user.home")) + .thenReturn("/User/Home"); + try (MockedStatic fileUtilMockedStatic = Mockito.mockStatic(FileUtil.class)) { + fileUtilMockedStatic.when(() -> FileUtil.isWritable("/User/Home")).thenReturn(true); + File defaultCacheDir = FileCacheManager.getDefaultCacheDir(); + Assertions.assertNotNull(defaultCacheDir); + Assertions.assertEquals( + "/User/Home/AppData/Local/Snowflake/Caches", defaultCacheDir.getAbsolutePath()); + } + } + } + } + + @Test + public void shouldCreateCacheDirForMacOS() { + try (MockedStatic constantsMockedStatic = Mockito.mockStatic(Constants.class)) { + constantsMockedStatic.when(Constants::getOS).thenReturn(Constants.OS.MAC); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when(() -> SnowflakeUtil.systemGetProperty("user.home")) + .thenReturn("/User/Home"); + try (MockedStatic fileUtilMockedStatic = Mockito.mockStatic(FileUtil.class)) { + fileUtilMockedStatic.when(() -> FileUtil.isWritable("/User/Home")).thenReturn(true); + File defaultCacheDir = FileCacheManager.getDefaultCacheDir(); + Assertions.assertNotNull(defaultCacheDir); + Assertions.assertEquals( + "/User/Home/Library/Caches/Snowflake", defaultCacheDir.getAbsolutePath()); + } + } + } + } + + @Test + public void shouldReturnNullWhenNoHomeDirSet() { + try (MockedStatic constantsMockedStatic = Mockito.mockStatic(Constants.class)) { + constantsMockedStatic.when(Constants::getOS).thenReturn(Constants.OS.LINUX); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when(() -> SnowflakeUtil.systemGetEnv("XDG_CACHE_HOME")) + .thenReturn(null); + snowflakeUtilMockedStatic + .when(() -> SnowflakeUtil.systemGetProperty("user.home")) + .thenReturn(null); + File defaultCacheDir = FileCacheManager.getDefaultCacheDir(); + Assertions.assertNull(defaultCacheDir); + } + } } private File createCacheFile() { Path cacheFile = - Paths.get(systemGetProperty("user.home"), ".cache", "snowflake2", CACHE_FILE_NAME); + Paths.get(systemGetProperty("user.home"), ".cache", "snowflake_cache", CACHE_FILE_NAME); try { if (Files.exists(cacheFile)) { Files.delete(cacheFile); } - Files.createDirectories(cacheFile.getParent()); + if (Files.exists(cacheFile.getParent())) { + Files.delete(cacheFile.getParent()); + } + Files.createDirectories( + cacheFile.getParent(), + PosixFilePermissions.asFileAttribute( + Stream.of( + PosixFilePermission.OWNER_READ, + PosixFilePermission.OWNER_WRITE, + PosixFilePermission.OWNER_EXECUTE) + .collect(Collectors.toSet()))); Files.createFile( cacheFile, PosixFilePermissions.asFileAttribute( diff --git a/src/test/java/net/snowflake/client/core/SecureStorageManagerTest.java b/src/test/java/net/snowflake/client/core/SecureStorageManagerTest.java index b79875038..c135cb6a7 100644 --- a/src/test/java/net/snowflake/client/core/SecureStorageManagerTest.java +++ b/src/test/java/net/snowflake/client/core/SecureStorageManagerTest.java @@ -4,6 +4,7 @@ package net.snowflake.client.core; +import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetProperty; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.notNullValue; @@ -13,6 +14,7 @@ import com.sun.jna.Memory; import com.sun.jna.Pointer; import com.sun.jna.ptr.PointerByReference; +import java.nio.file.Paths; import java.util.HashMap; import java.util.Iterator; import java.util.Map; @@ -20,7 +22,11 @@ import net.snowflake.client.annotations.RunOnMac; import net.snowflake.client.annotations.RunOnWindows; import net.snowflake.client.annotations.RunOnWindowsOrMac; +import net.snowflake.client.jdbc.SnowflakeUtil; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; class MockAdvapi32Lib implements SecureStorageWindowsManager.Advapi32Lib { @Override @@ -224,6 +230,34 @@ public class SecureStorageManagerTest { private static final String ID_TOKEN = "ID_TOKEN"; private static final String MFA_TOKEN = "MFATOKEN"; + @Test + public void testBuildCredentialsKey() { + // hex values obtained using https://emn178.github.io/online-tools/sha256.html + String hashedKey = + SecureStorageManager.buildCredentialsKey( + host, user, CachedCredentialType.OAUTH_ACCESS_TOKEN.getValue()); + Assertions.assertEquals( + "A7C7EBB89312E88552CD00664A0E20929801FACFBD682BF7C2363FB6EC8F914E", hashedKey); + + hashedKey = + SecureStorageManager.buildCredentialsKey( + host, user, CachedCredentialType.OAUTH_REFRESH_TOKEN.getValue()); + Assertions.assertEquals( + "DB37028833FA02B125FBD6DE8CE679C7E62E7D38FAC585E98060E00987F96772", hashedKey); + + hashedKey = + SecureStorageManager.buildCredentialsKey( + host, user, CachedCredentialType.ID_TOKEN.getValue()); + Assertions.assertEquals( + "6AA3F783E07D1D2182DAB59442806E2433C55C2BD4D9240790FD5B4B91FD4FDB", hashedKey); + + hashedKey = + SecureStorageManager.buildCredentialsKey( + host, user, CachedCredentialType.MFA_TOKEN.getValue()); + Assertions.assertEquals( + "9D10D4EFE45605D85993C6AC95334F1B63D36611B83615656EC7F277A947BF4B", hashedKey); + } + @Test @RunOnWindowsOrMac public void testLoadNativeLibrary() { @@ -260,10 +294,22 @@ public void testMacManager() { @Test @RunOnLinux public void testLinuxManager() { - SecureStorageManager manager = SecureStorageLinuxManager.getInstance(); - - testBody(manager); - testDeleteLinux(manager); + String cacheDirectory = + Paths.get(systemGetProperty("user.home"), ".cache", "snowflake_test_cache") + .toAbsolutePath() + .toString(); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when( + () -> + SnowflakeUtil.systemGetProperty("net.snowflake.jdbc.temporaryCredentialCacheDir")) + .thenReturn(cacheDirectory); + SecureStorageManager manager = SecureStorageLinuxManager.getInstance(); + + testBody(manager); + testDeleteLinux(manager); + } } private void testBody(SecureStorageManager manager) { diff --git a/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java b/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java index 88451aa0e..97d8d8258 100644 --- a/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java +++ b/src/test/java/net/snowflake/client/core/SnowflakeMFACacheTest.java @@ -4,6 +4,7 @@ package net.snowflake.client.core; +import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetProperty; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -19,6 +20,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.StringWriter; +import java.nio.file.Paths; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; @@ -28,6 +30,7 @@ import net.snowflake.client.AbstractDriverIT; import net.snowflake.client.jdbc.SnowflakeBasicDataSource; import net.snowflake.client.jdbc.SnowflakeSQLException; +import net.snowflake.client.jdbc.SnowflakeUtil; import org.apache.commons.io.IOUtils; import org.apache.http.client.methods.HttpPost; import org.junit.jupiter.api.Disabled; @@ -100,131 +103,152 @@ private Properties getBaseProp() { @Test public void testMFAFunctionality() throws SQLException { - SessionUtil.deleteMfaTokenCache(host, user); - try (MockedStatic mockedHttpUtil = Mockito.mockStatic(HttpUtil.class)) { - mockedHttpUtil - .when( - () -> - HttpUtil.executeGeneralRequest( - any(HttpPost.class), - anyInt(), - anyInt(), - anyInt(), - anyInt(), - any(HttpClientSettingsKey.class))) - .thenAnswer( - new Answer() { - int callCount = 0; - - @Override - public String answer(InvocationOnMock invocation) throws Throwable { - String res; - JsonNode jsonNode; - final Object[] args = invocation.getArguments(); - - if (callCount == 0) { - // First connection request - jsonNode = parseRequest((HttpPost) args[0]); - assertTrue( - jsonNode - .path("data") - .path("SESSION_PARAMETERS") - .path("CLIENT_REQUEST_MFA_TOKEN") - .asBoolean()); - // return first mfa token - res = getNormalMockedHttpResponse(true, 0).toString(); - } else if (callCount == 1) { - // First close() request - res = getNormalMockedHttpResponse(true, -1).toString(); - } else if (callCount == 2) { - // Second connection request - jsonNode = parseRequest((HttpPost) args[0]); - assertTrue( - jsonNode - .path("data") - .path("SESSION_PARAMETERS") - .path("CLIENT_REQUEST_MFA_TOKEN") - .asBoolean()); - assertEquals(mockedMfaToken[0], jsonNode.path("data").path("TOKEN").asText()); - // Normally backend won't send a new mfa token in this case. For testing - // purpose, we issue a new token to test whether the mfa token can be refreshed - // when receiving a new one from server. - res = getNormalMockedHttpResponse(true, 1).toString(); - } else if (callCount == 3) { - // Second close() request - res = getNormalMockedHttpResponse(true, -1).toString(); - } else if (callCount == 4) { - // Third connection request - // Check for the new mfa token - jsonNode = parseRequest((HttpPost) args[0]); - assertTrue( - jsonNode - .path("data") - .path("SESSION_PARAMETERS") - .path("CLIENT_REQUEST_MFA_TOKEN") - .asBoolean()); - assertEquals(mockedMfaToken[1], jsonNode.path("data").path("TOKEN").asText()); - res = getNormalMockedHttpResponse(true, -1).toString(); - } else if (callCount == 5) { - // Third close() request - res = getNormalMockedHttpResponse(true, -1).toString(); - } else if (callCount == 6) { - // test if failed log in response can delete the cached mfa token - res = getNormalMockedHttpResponse(false, -1).toString(); - } else if (callCount == 7) { - jsonNode = parseRequest((HttpPost) args[0]); - assertTrue( - jsonNode - .path("data") - .path("SESSION_PARAMETERS") - .path("CLIENT_REQUEST_MFA_TOKEN") - .asBoolean()); - // no token should be included this time. - assertEquals("", jsonNode.path("data").path("TOKEN").asText()); - res = getNormalMockedHttpResponse(true, -1).toString(); - } else if (callCount == 8) { - // final close() - res = getNormalMockedHttpResponse(true, -1).toString(); - } else { - // unexpected request - res = getNormalMockedHttpResponse(false, -1).toString(); - } - - callCount += 1; // this will be incremented on both connecting and closing - return res; - } - }); - - Properties prop = getBaseProp(); - - // connect url - String url = "jdbc:snowflake://testaccount.snowflakecomputing.com"; - - // The first connection contains no mfa token. After the connection, a mfa token will be saved - Connection con = DriverManager.getConnection(url, prop); - con.close(); - - // The second connection is expected to include the mfa token issued for the first connection - // and a new mfa token is issued - Connection con1 = DriverManager.getConnection(url, prop); - con1.close(); - - // The third connection is expected to include the new mfa token. - Connection con2 = DriverManager.getConnection(url, prop); - con2.close(); - - // This connection would receive an exception and then should clean up the mfa cache - try { - Connection con3 = DriverManager.getConnection(url, prop); - fail(); - } catch (SnowflakeSQLException ex) { - // An exception is forced to happen by mocking. Do nothing. + try (MockedStatic constantsMockedStatic = Mockito.mockStatic(Constants.class)) { + constantsMockedStatic.when(Constants::getOS).thenReturn(Constants.OS.LINUX); + String cacheDirectory = + Paths.get(systemGetProperty("user.home"), ".cache", "snowflake_test_cache") + .toAbsolutePath() + .toString(); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when( + () -> + SnowflakeUtil.systemGetProperty( + "net.snowflake.jdbc.temporaryCredentialCacheDir")) + .thenReturn(cacheDirectory); + SessionUtil.deleteMfaTokenCache(host, user); + try (MockedStatic mockedHttpUtil = Mockito.mockStatic(HttpUtil.class)) { + mockedHttpUtil + .when( + () -> + HttpUtil.executeGeneralRequest( + any(HttpPost.class), + anyInt(), + anyInt(), + anyInt(), + anyInt(), + any(HttpClientSettingsKey.class))) + .thenAnswer( + new Answer() { + int callCount = 0; + + @Override + public String answer(InvocationOnMock invocation) throws Throwable { + String res; + JsonNode jsonNode; + final Object[] args = invocation.getArguments(); + + if (callCount == 0) { + // First connection request + jsonNode = parseRequest((HttpPost) args[0]); + assertTrue( + jsonNode + .path("data") + .path("SESSION_PARAMETERS") + .path("CLIENT_REQUEST_MFA_TOKEN") + .asBoolean()); + // return first mfa token + res = getNormalMockedHttpResponse(true, 0).toString(); + } else if (callCount == 1) { + // First close() request + res = getNormalMockedHttpResponse(true, -1).toString(); + } else if (callCount == 2) { + // Second connection request + jsonNode = parseRequest((HttpPost) args[0]); + assertTrue( + jsonNode + .path("data") + .path("SESSION_PARAMETERS") + .path("CLIENT_REQUEST_MFA_TOKEN") + .asBoolean()); + assertEquals( + mockedMfaToken[0], jsonNode.path("data").path("TOKEN").asText()); + // Normally backend won't send a new mfa token in this case. For testing + // purpose, we issue a new token to test whether the mfa token can be + // refreshed + // when receiving a new one from server. + res = getNormalMockedHttpResponse(true, 1).toString(); + } else if (callCount == 3) { + // Second close() request + res = getNormalMockedHttpResponse(true, -1).toString(); + } else if (callCount == 4) { + // Third connection request + // Check for the new mfa token + jsonNode = parseRequest((HttpPost) args[0]); + assertTrue( + jsonNode + .path("data") + .path("SESSION_PARAMETERS") + .path("CLIENT_REQUEST_MFA_TOKEN") + .asBoolean()); + assertEquals( + mockedMfaToken[1], jsonNode.path("data").path("TOKEN").asText()); + res = getNormalMockedHttpResponse(true, -1).toString(); + } else if (callCount == 5) { + // Third close() request + res = getNormalMockedHttpResponse(true, -1).toString(); + } else if (callCount == 6) { + // test if failed log in response can delete the cached mfa token + res = getNormalMockedHttpResponse(false, -1).toString(); + } else if (callCount == 7) { + jsonNode = parseRequest((HttpPost) args[0]); + assertTrue( + jsonNode + .path("data") + .path("SESSION_PARAMETERS") + .path("CLIENT_REQUEST_MFA_TOKEN") + .asBoolean()); + // no token should be included this time. + assertEquals("", jsonNode.path("data").path("TOKEN").asText()); + res = getNormalMockedHttpResponse(true, -1).toString(); + } else if (callCount == 8) { + // final close() + res = getNormalMockedHttpResponse(true, -1).toString(); + } else { + // unexpected request + res = getNormalMockedHttpResponse(false, -1).toString(); + } + + callCount += 1; // this will be incremented on both connecting and closing + return res; + } + }); + + Properties prop = getBaseProp(); + + // connect url + String url = "jdbc:snowflake://testaccount.snowflakecomputing.com"; + + // The first connection contains no mfa token. After the connection, a mfa token will be + // saved + Connection con = DriverManager.getConnection(url, prop); + con.close(); + + // The second connection is expected to include the mfa token issued for the first + // connection + // and a new mfa token is issued + Connection con1 = DriverManager.getConnection(url, prop); + con1.close(); + + // The third connection is expected to include the new mfa token. + Connection con2 = DriverManager.getConnection(url, prop); + con2.close(); + + // This connection would receive an exception and then should clean up the mfa cache + try { + Connection con3 = DriverManager.getConnection(url, prop); + fail(); + } catch (SnowflakeSQLException ex) { + // An exception is forced to happen by mocking. Do nothing. + } + // This connect request should not contain mfa cached token + Connection con4 = DriverManager.getConnection(url, prop); + con4.close(); + } + SessionUtil.deleteMfaTokenCache(host, user); } - // This connect request should not contain mfa cached token - Connection con4 = DriverManager.getConnection(url, prop); - con4.close(); } - SessionUtil.deleteMfaTokenCache(host, user); } class MockUnavailableAdvapi32Lib implements SecureStorageWindowsManager.Advapi32Lib { diff --git a/src/test/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactoryTest.java b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactoryTest.java index 0a1c74d9a..840f0e3d1 100644 --- a/src/test/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactoryTest.java +++ b/src/test/java/net/snowflake/client/core/auth/oauth/OAuthAccessTokenProviderFactoryTest.java @@ -84,7 +84,8 @@ public void shouldFailToCreateClientCredentialsAccessTokenProviderWithoutClientI AuthenticatorType.OAUTH_CLIENT_CREDENTIALS, loginInput)); Assertions.assertTrue( e.getMessage() - .contains("passing oauthClientId is required for OAUTH_CLIENT_CREDENTIALS authentication.")); + .contains( + "passing oauthClientId is required for OAUTH_CLIENT_CREDENTIALS authentication.")); } @Test @@ -138,7 +139,8 @@ public void shouldFailToCreateAuthzCodeAccessTokenProviderWithoutClientId() { AuthenticatorType.OAUTH_AUTHORIZATION_CODE, loginInput)); Assertions.assertTrue( e.getMessage() - .contains("passing oauthClientId is required for OAUTH_AUTHORIZATION_CODE authentication.")); + .contains( + "passing oauthClientId is required for OAUTH_AUTHORIZATION_CODE authentication.")); } @Test diff --git a/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java b/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java index 71618c9e9..86a77eee0 100644 --- a/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java +++ b/src/test/java/net/snowflake/client/jdbc/SSOConnectionTest.java @@ -4,6 +4,7 @@ package net.snowflake.client.jdbc; +import static net.snowflake.client.jdbc.SnowflakeUtil.systemGetProperty; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -25,10 +26,12 @@ import java.net.Socket; import java.net.URI; import java.nio.charset.StandardCharsets; +import java.nio.file.Paths; import java.sql.Connection; import java.sql.DriverManager; import java.util.ArrayList; import java.util.Properties; +import net.snowflake.client.core.Constants; import net.snowflake.client.core.HttpClientSettingsKey; import net.snowflake.client.core.HttpUtil; import net.snowflake.client.core.SFException; @@ -306,38 +309,53 @@ private SFLoginInput initMockLoginInput() { @Test public void testIdTokenInSSO() throws Throwable { - try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class); - MockedStatic mockedSessionUtilExternalBrowser = - mockStatic(SessionUtilExternalBrowser.class)) { - - initMock(mockedHttpUtil, mockedSessionUtilExternalBrowser); - SessionUtil.deleteIdTokenCache("testaccount.snowflakecomputing.com", "testuser"); - - Properties properties = new Properties(); - properties.put("user", "testuser"); - properties.put("password", "testpassword"); - properties.put("account", "testaccount"); - properties.put("insecureMode", true); - properties.put("authenticator", "externalbrowser"); - properties.put("CLIENT_STORE_TEMPORARY_CREDENTIAL", true); - - // connect url - String url = "jdbc:snowflake://testaccount.snowflakecomputing.com"; - - // initial connection getting id token and storing in the cache file. - Connection con = DriverManager.getConnection(url, properties); - SnowflakeConnectionV1 sfcon = (SnowflakeConnectionV1) con; - assertThat("token", sfcon.getSfSession().getSessionToken(), equalTo(MOCK_SESSION_TOKEN)); - assertThat("idToken", sfcon.getSfSession().getIdToken(), equalTo(MOCK_ID_TOKEN)); - - // second connection reads the cache and use the id token to get the - // session token. - Connection conSecond = DriverManager.getConnection(url, properties); - SnowflakeConnectionV1 sfconSecond = (SnowflakeConnectionV1) conSecond; - assertThat( - "token", sfconSecond.getSfSession().getSessionToken(), equalTo(MOCK_NEW_SESSION_TOKEN)); - // we won't get a new id_token here - assertThat("idToken", sfcon.getSfSession().getIdToken(), equalTo(MOCK_ID_TOKEN)); + try (MockedStatic constantsMockedStatic = Mockito.mockStatic(Constants.class)) { + constantsMockedStatic.when(Constants::getOS).thenReturn(Constants.OS.LINUX); + String cacheDirectory = + Paths.get(systemGetProperty("user.home"), ".cache", "snowflake_test_cache") + .toAbsolutePath() + .toString(); + try (MockedStatic snowflakeUtilMockedStatic = + Mockito.mockStatic(SnowflakeUtil.class)) { + snowflakeUtilMockedStatic + .when(() -> systemGetProperty("net.snowflake.jdbc.temporaryCredentialCacheDir")) + .thenReturn(cacheDirectory); + try (MockedStatic mockedHttpUtil = mockStatic(HttpUtil.class); + MockedStatic mockedSessionUtilExternalBrowser = + mockStatic(SessionUtilExternalBrowser.class)) { + + initMock(mockedHttpUtil, mockedSessionUtilExternalBrowser); + SessionUtil.deleteIdTokenCache("testaccount.snowflakecomputing.com", "testuser"); + + Properties properties = new Properties(); + properties.put("user", "testuser"); + properties.put("password", "testpassword"); + properties.put("account", "testaccount"); + properties.put("insecureMode", true); + properties.put("authenticator", "externalbrowser"); + properties.put("CLIENT_STORE_TEMPORARY_CREDENTIAL", true); + + // connect url + String url = "jdbc:snowflake://testaccount.snowflakecomputing.com"; + + // initial connection getting id token and storing in the cache file. + Connection con = DriverManager.getConnection(url, properties); + SnowflakeConnectionV1 sfcon = (SnowflakeConnectionV1) con; + assertThat("token", sfcon.getSfSession().getSessionToken(), equalTo(MOCK_SESSION_TOKEN)); + assertThat("idToken", sfcon.getSfSession().getIdToken(), equalTo(MOCK_ID_TOKEN)); + + // second connection reads the cache and use the id token to get the + // session token. + Connection conSecond = DriverManager.getConnection(url, properties); + SnowflakeConnectionV1 sfconSecond = (SnowflakeConnectionV1) conSecond; + assertThat( + "token", + sfconSecond.getSfSession().getSessionToken(), + equalTo(MOCK_NEW_SESSION_TOKEN)); + // we won't get a new id_token here + assertThat("idToken", sfcon.getSfSession().getIdToken(), equalTo(MOCK_ID_TOKEN)); + } + } } } }