diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java index faa74ce86..67faec889 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/EncryptionProvider.java @@ -41,8 +41,8 @@ public class EncryptionProvider { private static final String FILE_CIPHER = "AES/CBC/PKCS5Padding"; private static final String KEY_CIPHER = "AES/ECB/PKCS5Padding"; private static final int BUFFER_SIZE = 2 * 1024 * 1024; // 2 MB - private static ThreadLocal secRnd = - new ThreadLocal<>().withInitial(SecureRandom::new); + private static final ThreadLocal SEC_RND = + ThreadLocal.withInitial(SecureRandom::new); /** * Decrypt a InputStream @@ -70,7 +70,7 @@ public static InputStream decryptStream( byte[] kekBytes = Base64.getDecoder().decode(encMat.getQueryStageMasterKey()); byte[] keyBytes = Base64.getDecoder().decode(keyBase64); byte[] ivBytes = Base64.getDecoder().decode(ivBase64); - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); keyCipher.init(Cipher.DECRYPT_MODE, kek); byte[] fileKeyBytes = keyCipher.doFinal(keyBytes); @@ -98,7 +98,7 @@ public static void decrypt( // Decrypt file key { final Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); keyCipher.init(Cipher.DECRYPT_MODE, kek); byte[] fileKeyBytes = keyCipher.doFinal(keyBytes); @@ -166,11 +166,11 @@ public static CipherInputStream encrypt( // Create IV ivData = new byte[blockSize]; - secRnd.get().nextBytes(ivData); + SEC_RND.get().nextBytes(ivData); final IvParameterSpec iv = new IvParameterSpec(ivData); // Create file key - secRnd.get().nextBytes(fileKeyBytes); + SEC_RND.get().nextBytes(fileKeyBytes); SecretKey fileKey = new SecretKeySpec(fileKeyBytes, 0, keySize, AES); // Init cipher diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java index b4a4682d8..6859b609b 100644 --- a/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java @@ -83,7 +83,7 @@ private static void initRandomIvsAndFileKey( private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvData, byte[] aad) throws InvalidKeyException, InvalidAlgorithmParameterException, IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData); Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec); @@ -97,7 +97,7 @@ private static CipherInputStream encryptContent( InputStream src, byte[] keyBytes, byte[] dataIvBytes, byte[] aad) throws InvalidKeyException, InvalidAlgorithmParameterException, NoSuchPaddingException, NoSuchAlgorithmException { - SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES); + SecretKey fileKey = new SecretKeySpec(keyBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes); Cipher fileCipher = Cipher.getInstance(FILE_CIPHER); fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec); @@ -213,7 +213,7 @@ private static void decryptContentFromFile( private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyBytes, byte[] aad) throws InvalidKeyException, InvalidAlgorithmParameterException, IllegalBlockSizeException, BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException { - SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES); + SecretKey kek = new SecretKeySpec(kekBytes, AES); GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes); Cipher keyCipher = Cipher.getInstance(KEY_CIPHER); keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec); diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java new file mode 100644 index 000000000..1591780e4 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Aead.java @@ -0,0 +1,62 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.util.function.Supplier; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.AeadProvider; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm; + +public enum Aead { + AES_GCM_256((byte) 0, "AES", "AES/GCM/NoPadding", 32, 12, 16, () -> new Gcm(16)); + + private byte id; + private String jceKeyTypeName; + private String jceFullName; + private int keyLength; + private int ivLength; + private int authTagLength; + private Supplier aeadProvider; + + Aead( + byte id, + String jceKeyTypeName, + String jceFullName, + int keyLength, + int ivLength, + int authTagLength, + Supplier aeadProvider) { + this.jceKeyTypeName = jceKeyTypeName; + this.jceFullName = jceFullName; + this.keyLength = keyLength; + this.id = id; + this.ivLength = ivLength; + this.authTagLength = authTagLength; + this.aeadProvider = aeadProvider; + } + + byte getId() { + return id; + } + + public String getJceKeyTypeName() { + return jceKeyTypeName; + } + + String getJceFullName() { + return jceFullName; + } + + int getKeyLength() { + return keyLength; + } + + int getIvLength() { + return ivLength; + } + + int getAuthTagLength() { + return authTagLength; + } + + AeadProvider getAeadProvider() { + return aeadProvider.get(); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java new file mode 100644 index 000000000..36bf52b4c --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadAad.java @@ -0,0 +1,26 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +class AeadAad { + private final byte[] bytes; + + private AeadAad(long segmentCounter, byte terminalityByte) { + ByteBuffer buf = ByteBuffer.allocate(9); + buf.putLong(segmentCounter); + buf.put(terminalityByte); + this.bytes = buf.array(); + } + + static AeadAad nonTerminal(long segmentCounter) { + return new AeadAad(segmentCounter, (byte) 0); + } + + static AeadAad terminal(long segmentCounter) { + return new AeadAad(segmentCounter, (byte) 1); + } + + byte[] getBytes() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java new file mode 100644 index 000000000..471fa7204 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadIv.java @@ -0,0 +1,25 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +class AeadIv { + private final byte[] bytes; + + AeadIv(byte[] bytes) { + this.bytes = bytes; + } + + static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) { + return new AeadIv(floeRandom.ofLength(ivLength)); + } + + static AeadIv from(ByteBuffer buffer, int ivLength) { + byte[] bytes = new byte[ivLength]; + buffer.get(bytes); + return new AeadIv(bytes); + } + + byte[] getBytes() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java new file mode 100644 index 000000000..bfbd01976 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/AeadKey.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; + +class AeadKey { + private final SecretKey key; + + AeadKey(SecretKey key) { + this.key = key; + } + + SecretKey getKey() { + return key; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java new file mode 100644 index 000000000..aa00aa8cd --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/BaseSegmentProcessor.java @@ -0,0 +1,74 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.io.Closeable; +import java.io.IOException; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; + +abstract class BaseSegmentProcessor implements Closeable { + protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1; + protected static final int headerTagLength = 32; + + protected final FloeParameterSpec parameterSpec; + protected final FloeKey floeKey; + protected final FloeAad floeAad; + + protected final KeyDerivator keyDerivator; + + private AeadKey currentAeadKey; + + private boolean isClosed; + private boolean completedExceptionally; + + BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { + this.parameterSpec = parameterSpec; + this.floeKey = floeKey; + this.floeAad = floeAad; + this.keyDerivator = new KeyDerivator(parameterSpec); + } + + protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { + if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) { + // we don't need masking, because we derive a new key only when key rotation happens + currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter); + } + return currentAeadKey; + } + + private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) { + byte[] keyBytes = + keyDerivator.hkdfExpand( + floeKey, + floeIv, + floeAad, + new DekTagFloePurpose(segmentCounter), + parameterSpec.getAead().getKeyLength()); + SecretKey key = new SecretKeySpec(keyBytes, parameterSpec.getAead().getJceKeyTypeName()); + return new AeadKey(key); + } + + protected void closeInternal() { + isClosed = true; + } + + protected void markAsCompletedExceptionally() { + completedExceptionally = true; + } + + protected void assertNotClosed() { + if (isClosed) { + throw new IllegalStateException("stream has already been closed"); + } + } + + @Override + public void close() throws IOException { + if (!isClosed && !completedExceptionally) { + throw new IllegalStateException("last segment was not processed"); + } + } + + protected boolean isClosed() { + return isClosed; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Floe.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Floe.java new file mode 100644 index 000000000..b3147a097 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Floe.java @@ -0,0 +1,23 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; + +public class Floe { + private final FloeParameterSpec parameterSpec; + + private Floe(FloeParameterSpec parameterSpec) { + this.parameterSpec = parameterSpec; + } + + public static Floe getInstance(FloeParameterSpec parameterSpec) { + return new Floe(parameterSpec); + } + + public FloeEncryptor createEncryptor(SecretKey key, byte[] aad) { + return new FloeEncryptorImpl(parameterSpec, new FloeKey(key), new FloeAad(aad)); + } + + public FloeDecryptor createDecryptor(SecretKey key, byte[] aad, byte[] floeHeader) { + return new FloeDecryptorImpl(parameterSpec, new FloeKey(key), new FloeAad(aad), floeHeader); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java new file mode 100644 index 000000000..f135d9b68 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeAad.java @@ -0,0 +1,16 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.util.Optional; + +class FloeAad { + private static final byte[] EMPTY_AAD = new byte[0]; + private final byte[] aad; + + FloeAad(byte[] aad) { + this.aad = Optional.ofNullable(aad).orElse(EMPTY_AAD); + } + + byte[] getBytes() { + return aad; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java new file mode 100644 index 000000000..986a738fe --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptor.java @@ -0,0 +1,7 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public interface FloeDecryptor extends AutoCloseable { + byte[] processSegment(byte[] ciphertext); + + boolean isClosed(); +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java new file mode 100644 index 000000000..ad440f8de --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImpl.java @@ -0,0 +1,152 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import java.util.Arrays; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.AeadProvider; + +class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor { + private final FloeIv floeIv; + private final AeadProvider aeadProvider; + + FloeDecryptorImpl( + FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) { + super(parameterSpec, floeKey, floeAad); + byte[] encodedParams = this.parameterSpec.paramEncode(); + if (floeHeaderAsBytes.length + != encodedParams.length + + this.parameterSpec.getFloeIvLength().getLength() + + headerTagLength) { + throw new IllegalArgumentException("invalid header length"); + } + ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes); + + byte[] encodedParamsFromHeader = new byte[10]; + floeHeader.get(encodedParamsFromHeader, 0, encodedParamsFromHeader.length); + if (!Arrays.equals(encodedParams, encodedParamsFromHeader)) { + throw new IllegalArgumentException("invalid parameters header"); + } + + byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()]; + floeHeader.get(floeIvBytes, 0, floeIvBytes.length); + this.floeIv = new FloeIv(floeIvBytes); + this.aeadProvider = parameterSpec.getAead().getAeadProvider(); + + byte[] headerTagFromHeader = new byte[headerTagLength]; + floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length); + + byte[] headerTag = + keyDerivator.hkdfExpand( + this.floeKey, floeIv, this.floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength); + if (!Arrays.equals(headerTag, headerTagFromHeader)) { + throw new IllegalArgumentException("invalid header tag"); + } + } + + private long segmentCounter; + + @Override + public byte[] processSegment(byte[] input) { + assertNotClosed(); + ByteBuffer inputBuffer = ByteBuffer.wrap(input); + try { + if (isLastSegment(inputBuffer)) { + return processLastSegment(inputBuffer); + } else { + return processNonLastSegment(inputBuffer); + } + } catch (Exception e) { + markAsCompletedExceptionally(); + throw e; + } + } + + private boolean isLastSegment(ByteBuffer inputBuffer) { + int segmentSizeMarker = inputBuffer.getInt(); + try { + return segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER; + } finally { + inputBuffer.rewind(); + } + } + + private byte[] processNonLastSegment(ByteBuffer inputBuf) { + try { + verifyNonLastSegmentLength(inputBuf); + verifySegmentSizeMarker(inputBuf); + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter); + byte[] ciphertext = new byte[inputBuf.remaining()]; + inputBuf.get(ciphertext); + byte[] decrypted = + aeadProvider.decrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + segmentCounter++; + return decrypted; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private void verifyNonLastSegmentLength(ByteBuffer inputBuf) { + if (inputBuf.capacity() != parameterSpec.getEncryptedSegmentLength()) { + throw new IllegalArgumentException( + String.format( + "segment length mismatch, expected %d, got %d", + parameterSpec.getEncryptedSegmentLength(), inputBuf.capacity())); + } + } + + private void verifySegmentSizeMarker(ByteBuffer inputBuf) { + int segmentSizeMarker = inputBuf.getInt(); + if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) { + throw new IllegalArgumentException( + String.format( + "segment length marker mismatch, expected: %d, got: %d", + NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker)); + } + } + + private byte[] processLastSegment(ByteBuffer inputBuf) { + verifyLastSegmentLength(inputBuf); + verifyLastSegmentSizeMarker(inputBuf); + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.terminal(segmentCounter); + byte[] ciphertext = new byte[inputBuf.remaining()]; + inputBuf.get(ciphertext); + byte[] decrypted = + aeadProvider.decrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext); + closeInternal(); + return decrypted; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + private void verifyLastSegmentLength(ByteBuffer inputBuf) { + if (inputBuf.capacity() + < 4 + parameterSpec.getAead().getIvLength() + parameterSpec.getAead().getAuthTagLength()) { + throw new IllegalArgumentException("last segment is too short"); + } + if (inputBuf.capacity() > parameterSpec.getEncryptedSegmentLength()) { + throw new IllegalArgumentException("last segment is too long"); + } + } + + private void verifyLastSegmentSizeMarker(ByteBuffer inputBuf) { + int segmentLengthFromSegment = inputBuf.getInt(); + if (segmentLengthFromSegment != inputBuf.capacity()) { + throw new IllegalArgumentException( + String.format( + "last segment length marker mismatch, expected: %d, got: %d", + inputBuf.capacity(), segmentLengthFromSegment)); + } + } + + @Override + public boolean isClosed() { + return super.isClosed(); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java new file mode 100644 index 000000000..911ba2e81 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptor.java @@ -0,0 +1,11 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public interface FloeEncryptor extends AutoCloseable { + byte[] processSegment(byte[] plaintext); + + byte[] processLastSegment(byte[] plaintext); + + byte[] getHeader(); + + boolean isClosed(); +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java new file mode 100644 index 000000000..a7677fc9a --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImpl.java @@ -0,0 +1,144 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.security.GeneralSecurityException; +import net.snowflake.client.jdbc.cloud.storage.floe.aead.AeadProvider; + +class FloeEncryptorImpl extends BaseSegmentProcessor implements FloeEncryptor { + private final FloeIv floeIv; + private final AeadProvider aeadProvider; + private AeadKey currentAeadKey; + + private long segmentCounter; + + private final byte[] header; + + FloeEncryptorImpl(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) { + super(parameterSpec, floeKey, floeAad); + this.floeIv = + FloeIv.generateRandom(parameterSpec.getFloeRandom(), parameterSpec.getFloeIvLength()); + this.aeadProvider = parameterSpec.getAead().getAeadProvider(); + this.header = buildHeader(); + } + + private byte[] buildHeader() { + byte[] parametersEncoded = parameterSpec.paramEncode(); + byte[] floeIvBytes = floeIv.getBytes(); + byte[] headerTag = + keyDerivator.hkdfExpand( + floeKey, floeIv, floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength); + + ByteBuffer result = + ByteBuffer.allocate(parametersEncoded.length + floeIvBytes.length + headerTag.length); + result.put(parametersEncoded); + result.put(floeIvBytes); + result.put(headerTag); + if (result.hasRemaining()) { + throw new IllegalArgumentException("Header is too long"); + } + return result.array(); + } + + @Override + public byte[] getHeader() { + return header; + } + + @Override + public byte[] processSegment(byte[] input) { + assertNotClosed(); + try { + verifySegmentLength(input); + verifyMaxSegmentNumberNotReached(); + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = + AeadIv.generateRandom( + parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter); + // it works as long as AEAD returns auth tag as a part of the ciphertext + byte[] ciphertextWithAuthTag = + aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); + byte[] encoded = segmentToBytes(aeadIv, ciphertextWithAuthTag); + segmentCounter++; + return encoded; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } catch (Exception e) { + markAsCompletedExceptionally(); + throw e; + } + } + + private void verifySegmentLength(byte[] input) { + if (input.length != parameterSpec.getPlainTextSegmentLength()) { + throw new IllegalArgumentException( + String.format( + "segment length mismatch, expected %d, got %d", + parameterSpec.getPlainTextSegmentLength(), input.length)); + } + } + + private void verifyMaxSegmentNumberNotReached() { + if (segmentCounter >= parameterSpec.getMaxSegmentNumber() - 1) { + throw new IllegalStateException("maximum segment number reached"); + } + } + + private byte[] segmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { + ByteBuffer output = ByteBuffer.allocate(parameterSpec.getEncryptedSegmentLength()); + output.putInt(NON_TERMINAL_SEGMENT_SIZE_MARKER); + output.put(aeadIv.getBytes()); + output.put(ciphertextWithAuthTag); + return output.array(); + } + + @Override + public byte[] processLastSegment(byte[] input) { + assertNotClosed(); + try { + verifyLastSegmentNotEmpty(input); + try { + AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter); + AeadIv aeadIv = + AeadIv.generateRandom( + parameterSpec.getFloeRandom(), parameterSpec.getAead().getIvLength()); + AeadAad aeadAad = AeadAad.terminal(segmentCounter); + byte[] ciphertextWithAuthTag = + aeadProvider.encrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), input); + byte[] lastSegmentBytes = lastSegmentToBytes(aeadIv, ciphertextWithAuthTag); + closeInternal(); + return lastSegmentBytes; + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } catch (Exception e) { + markAsCompletedExceptionally(); + throw e; + } + } + + private byte[] lastSegmentToBytes(AeadIv aeadIv, byte[] ciphertextWithAuthTag) { + int lastSegmentLength = 4 + aeadIv.getBytes().length + ciphertextWithAuthTag.length; + ByteBuffer output = ByteBuffer.allocate(lastSegmentLength); + output.putInt(lastSegmentLength); + output.put(aeadIv.getBytes()); + output.put(ciphertextWithAuthTag); + return output.array(); + } + + private void verifyLastSegmentNotEmpty(byte[] input) { + if (input.length > parameterSpec.getPlainTextSegmentLength()) { + throw new IllegalArgumentException( + String.format( + "last segment is too long, got %d, max is %d", + input.length, parameterSpec.getPlainTextSegmentLength())); + } + } + + @Override + public boolean isClosed() { + return super.isClosed(); + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIv.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIv.java new file mode 100644 index 000000000..1022510b5 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIv.java @@ -0,0 +1,21 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +class FloeIv { + private final byte[] bytes; + + FloeIv(byte[] bytes) { + this.bytes = bytes; + } + + static FloeIv generateRandom(FloeRandom floeRandom, FloeIvLength floeIvLength) { + return new FloeIv(floeRandom.ofLength(floeIvLength.getLength())); + } + + byte[] getBytes() { + return bytes; + } + + int lengthInBytes() { + return bytes.length; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java new file mode 100644 index 000000000..466005265 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeIvLength.java @@ -0,0 +1,13 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +class FloeIvLength { + private final int length; + + FloeIvLength(int length) { + this.length = length; + } + + int getLength() { + return length; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKey.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKey.java new file mode 100644 index 000000000..6b6bf9991 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeKey.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import javax.crypto.SecretKey; + +class FloeKey { + private final SecretKey key; + + FloeKey(SecretKey key) { + this.key = key; + } + + SecretKey getKey() { + return key; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java new file mode 100644 index 000000000..bb79314d4 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeParameterSpec.java @@ -0,0 +1,90 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class FloeParameterSpec { + private final Aead aead; + private final Hash hash; + private final int encryptedSegmentLength; + private final FloeIvLength floeIvLength; + private final FloeRandom floeRandom; + private final int keyRotationModulo; + private final long maxSegmentNumber; + + public FloeParameterSpec(Aead aead, Hash hash, int encryptedSegmentLength, int floeIvLength) { + this( + aead, + hash, + encryptedSegmentLength, + new FloeIvLength(floeIvLength), + new SecureFloeRandom(), + 1 << 20, + 1L << 40); + } + + FloeParameterSpec( + Aead aead, + Hash hash, + int encryptedSegmentLength, + FloeIvLength floeIvLength, + FloeRandom floeRandom, + int keyRotationModulo, + long maxSegmentNumber) { + this.aead = aead; + this.hash = hash; + this.encryptedSegmentLength = encryptedSegmentLength; + this.floeIvLength = floeIvLength; + this.floeRandom = floeRandom; + this.keyRotationModulo = keyRotationModulo; + this.maxSegmentNumber = maxSegmentNumber; + if (encryptedSegmentLength <= 0) { + throw new IllegalArgumentException("encryptedSegmentLength must be > 0"); + } + if (floeIvLength.getLength() <= 0) { + throw new IllegalArgumentException("floeIvLength must be > 0"); + } + } + + byte[] paramEncode() { + ByteBuffer result = ByteBuffer.allocate(10).order(ByteOrder.BIG_ENDIAN); + result.put(aead.getId()); + result.put(hash.getId()); + result.putInt(encryptedSegmentLength); + result.putInt(floeIvLength.getLength()); + return result.array(); + } + + public Aead getAead() { + return aead; + } + + public Hash getHash() { + return hash; + } + + public FloeIvLength getFloeIvLength() { + return floeIvLength; + } + + FloeRandom getFloeRandom() { + return floeRandom; + } + + int getEncryptedSegmentLength() { + return encryptedSegmentLength; + } + + int getPlainTextSegmentLength() { + // sizeof(int) == 4, file size is a part of the segment ciphertext + return encryptedSegmentLength - aead.getIvLength() - aead.getAuthTagLength() - 4; + } + + int getKeyRotationModulo() { + return keyRotationModulo; + } + + long getMaxSegmentNumber() { + return maxSegmentNumber; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java new file mode 100644 index 000000000..41fda867a --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloePurpose.java @@ -0,0 +1,39 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; + +interface FloePurpose { + byte[] generate(); +} + +class HeaderTagFloePurpose implements FloePurpose { + private static final byte[] bytes = "HEADER_TAG:".getBytes(StandardCharsets.UTF_8); + + static final HeaderTagFloePurpose INSTANCE = new HeaderTagFloePurpose(); + + private HeaderTagFloePurpose() {} + + @Override + public byte[] generate() { + return bytes; + } +} + +class DekTagFloePurpose implements FloePurpose { + private static final byte[] prefix = "DEK:".getBytes(StandardCharsets.UTF_8); + + private final byte[] bytes; + + DekTagFloePurpose(long segmentCount) { + ByteBuffer buffer = ByteBuffer.allocate(prefix.length + 8 /*size of long*/); + buffer.put(prefix); + buffer.putLong(segmentCount); + this.bytes = buffer.array(); + } + + @Override + public byte[] generate() { + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeRandom.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeRandom.java new file mode 100644 index 000000000..a0bd176f1 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeRandom.java @@ -0,0 +1,5 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +interface FloeRandom { + byte[] ofLength(int length); +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Hash.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Hash.java new file mode 100644 index 000000000..45f5b8c3a --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/Hash.java @@ -0,0 +1,21 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +public enum Hash { + SHA384((byte) 0, "HmacSHA384"); + + private byte id; + private final String jceName; + + Hash(byte id, String jceName) { + this.id = id; + this.jceName = jceName; + } + + byte getId() { + return id; + } + + public String getJceName() { + return jceName; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java new file mode 100644 index 000000000..e3d064ba1 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/KeyDerivator.java @@ -0,0 +1,48 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import javax.crypto.Mac; + +class KeyDerivator { + private final FloeParameterSpec parameterSpec; + + KeyDerivator(FloeParameterSpec parameterSpec) { + this.parameterSpec = parameterSpec; + } + + byte[] hkdfExpand( + FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, FloePurpose purpose, int length) { + byte[] encodedParams = parameterSpec.paramEncode(); + byte[] purposeBytes = purpose.generate(); + ByteBuffer info = + ByteBuffer.allocate( + encodedParams.length + + floeIv.getBytes().length + + purposeBytes.length + + floeAad.getBytes().length); + info.put(encodedParams); + info.put(floeIv.getBytes()); + info.put(purposeBytes); + info.put(floeAad.getBytes()); + return hkdfExpandInternal(parameterSpec.getHash(), floeKey, info.array(), length); + } + + private byte[] hkdfExpandInternal(Hash hash, FloeKey prk, byte[] info, int len) { + try { + Mac mac = Mac.getInstance(hash.getJceName()); + mac.init(prk.getKey()); + mac.update(info); + mac.update((byte) 1); + byte[] bytes = mac.doFinal(); + if (bytes.length != len) { + return Arrays.copyOf(bytes, len); + } + return bytes; + } catch (NoSuchAlgorithmException | InvalidKeyException e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SecureFloeRandom.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SecureFloeRandom.java new file mode 100644 index 000000000..9302c8c83 --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/SecureFloeRandom.java @@ -0,0 +1,15 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.security.SecureRandom; + +class SecureFloeRandom implements FloeRandom { + private static final ThreadLocal random = + ThreadLocal.withInitial(SecureRandom::new); + + @Override + public byte[] ofLength(int length) { + byte[] bytes = new byte[length]; + random.get().nextBytes(bytes); + return bytes; + } +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/AeadProvider.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/AeadProvider.java new file mode 100644 index 000000000..63fee639c --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/AeadProvider.java @@ -0,0 +1,20 @@ +package net.snowflake.client.jdbc.cloud.storage.floe.aead; + +import java.security.GeneralSecurityException; +import javax.crypto.SecretKey; + +// Consideration for implementations: +// 1. Implementations does not have to be thread safe, they are used in FLOE in a thread safe manner +// (FLOE encryptor and decryptor creates their own instances). +// 2. Authentication tag is a part of ciphertext: +// a) For encrypt function - auth tag is returned with ciphertext. +// b) For decrypt function - auth tag is passed with ciphertext. +// As long as it isn't strictly required to be at the end of the ciphertext, it is needed to be in +// the correct position for the underlying algorithm. +public interface AeadProvider { + byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) + throws GeneralSecurityException; + + byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) + throws GeneralSecurityException; +} diff --git a/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java new file mode 100644 index 000000000..616ff2aea --- /dev/null +++ b/src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/aead/Gcm.java @@ -0,0 +1,52 @@ +package net.snowflake.client.jdbc.cloud.storage.floe.aead; + +import java.security.GeneralSecurityException; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import javax.crypto.BadPaddingException; +import javax.crypto.Cipher; +import javax.crypto.IllegalBlockSizeException; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.spec.GCMParameterSpec; + +// This class is not thread safe! +// But as long as it is used only for FLOE, it is fine, as FLOE instance keeps its own instance of +// GCM. +public class Gcm implements AeadProvider { + private final Cipher keyCipher; + private final int tagLengthInBits; + + public Gcm(int tagLengthInBytes) { + try { + keyCipher = Cipher.getInstance("AES/GCM/NoPadding"); + this.tagLengthInBits = tagLengthInBytes * 8; + } catch (NoSuchAlgorithmException | NoSuchPaddingException e) { + throw new ExceptionInInitializerError(e); + } + } + + @Override + public byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) + throws GeneralSecurityException { + return process(key, iv, aad, plaintext, Cipher.ENCRYPT_MODE); + } + + @Override + public byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) + throws GeneralSecurityException { + return process(key, iv, aad, ciphertext, Cipher.DECRYPT_MODE); + } + + private byte[] process(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext, int encryptMode) + throws InvalidKeyException, InvalidAlgorithmParameterException, IllegalBlockSizeException, + BadPaddingException { + GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(tagLengthInBits, iv); + keyCipher.init(encryptMode, key, gcmParameterSpec); + if (aad != null) { + keyCipher.updateAAD(aad); + } + return keyCipher.doFinal(plaintext); + } +} diff --git a/src/test/java/net/snowflake/client/AbstractDriverIT.java b/src/test/java/net/snowflake/client/AbstractDriverIT.java index 3104ce7e9..c75b81818 100644 --- a/src/test/java/net/snowflake/client/AbstractDriverIT.java +++ b/src/test/java/net/snowflake/client/AbstractDriverIT.java @@ -324,6 +324,10 @@ public static Connection getConnection( properties.put("internal", Boolean.TRUE.toString()); // TODO: do we need this? properties.put("insecureMode", false); // use OCSP for all tests. + properties.put("useProxy", "true"); + properties.put("proxyHost", "localhost"); + properties.put("proxyPort", "8080"); + if (injectSocketTimeout > 0) { properties.put("injectSocketTimeout", String.valueOf(injectSocketTimeout)); } diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java new file mode 100644 index 000000000..c127ee069 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeDecryptorImplTest.java @@ -0,0 +1,137 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import javax.crypto.AEADBadTagException; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Test; + +class FloeDecryptorImplTest { + private final SecretKey secretKey = new SecretKeySpec(new byte[32], "AES"); + private final byte[] aad = "Test AAD".getBytes(StandardCharsets.UTF_8); + + @Test + void shouldDecryptCiphertext() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] firstSegment = encryptor.processSegment(new byte[8]); + byte[] lastSegment = encryptor.processLastSegment(new byte[4]); + + assertArrayEquals(new byte[8], decryptor.processSegment(firstSegment)); + assertArrayEquals(new byte[4], decryptor.processSegment(lastSegment)); + } + } + + @Test + void shouldDecryptLastSegmentZeroLength() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] lastSegment = encryptor.processLastSegment(new byte[0]); + assertArrayEquals(new byte[0], decryptor.processSegment(lastSegment)); + } + } + + @Test + void shouldDecryptLastSegmentFullLength() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] lastSegment = encryptor.processLastSegment(new byte[8]); + assertArrayEquals(new byte[8], decryptor.processSegment(lastSegment)); + } + } + + @Test + void shouldThrowExceptionIfSegmentLengthIsMismatched() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] ciphertext = encryptor.processSegment(new byte[8]); + byte[] prunedCiphertext = new byte[12]; + ByteBuffer.wrap(ciphertext).get(prunedCiphertext); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(prunedCiphertext)); + assertEquals("segment length mismatch, expected 40, got 12", e.getMessage()); + byte[] extendedCiphertext = new byte[1024]; + ByteBuffer.wrap(extendedCiphertext).put(ciphertext); + e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(extendedCiphertext)); + assertEquals("segment length mismatch, expected 40, got 1024", e.getMessage()); + encryptor.processLastSegment(new byte[4]); + } + } + + @Test + void shouldThrowExceptionIfLastSegmentLengthIsMismatched() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + encryptor.processLastSegment(new byte[4]); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(new byte[12])); + assertEquals("last segment is too short", e.getMessage()); + e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(new byte[1024])); + assertEquals("last segment is too long", e.getMessage()); + } + } + + @Test + void shouldThrowExceptionIfLastSegmentLengthMarkerIsNotMinusOne() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + encryptor.processLastSegment(new byte[4]); + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> decryptor.processSegment(new byte[40])); + assertEquals("last segment length marker mismatch, expected: 40, got: 0", e.getMessage()); + } + } + + @Test + void shouldThrowExceptionIfSegmentIsTampered() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] ciphertext = encryptor.processLastSegment(new byte[8]); + ciphertext[39]++; + RuntimeException e = + assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext)); + assertEquals(e.getCause().getClass(), AEADBadTagException.class); + } + } + + @Test + void shouldThrowExceptionIfSegmentAreOutOfOrder() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] ciphertext1 = encryptor.processSegment(new byte[8]); + byte[] ciphertext2 = encryptor.processSegment(new byte[8]); + encryptor.processLastSegment(new byte[4]); + RuntimeException e = + assertThrows(RuntimeException.class, () -> decryptor.processSegment(ciphertext2)); + assertEquals(e.getCause().getClass(), AEADBadTagException.class); + } + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java new file mode 100644 index 000000000..4681176e7 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeEncryptorImplTest.java @@ -0,0 +1,198 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import static org.junit.Assert.assertThrows; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class FloeEncryptorImplTest { + byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); + SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); + + @Test + void shouldCreateCorrectHeader() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 12345678, + new FloeIvLength(4), + new IncrementingFloeRandom(17), + 4, + 1L << 40); + parameterSpec.getFloeRandom().ofLength(4); // just to trigger incrementation + FloeKey floeKey = new FloeKey(new SecretKeySpec(new byte[32], "FLOE")); + FloeAad floeAad = new FloeAad("test aad".getBytes(StandardCharsets.UTF_8)); + try (FloeEncryptor encryptor = new FloeEncryptorImpl(parameterSpec, floeKey, floeAad)) { + byte[] header = encryptor.getHeader(); + // AEAD ID + assertEquals(Aead.AES_GCM_256.getId(), header[0]); + // HASH ID + assertEquals(Hash.SHA384.getId(), header[1]); + // Segment length in BE + // 12345678(10) = BC614E(16) + assertEquals(0, header[2]); + assertEquals((byte) 188, header[3]); + assertEquals((byte) 97, header[4]); + assertEquals((byte) 78, header[5]); + // FLOE IV length in BE + // 4(10) = 4(16) = 00,00,00,04 + assertEquals(0, header[6]); + assertEquals(0, header[7]); + assertEquals(0, header[8]); + assertEquals(4, header[9]); + // FLOE IV + assertEquals(0, header[10]); + assertEquals(0, header[11]); + assertEquals(0, header[12]); + assertEquals(18, header[13]); + + close(encryptor); + } + } + + private static byte[] close(FloeEncryptor encryptor) { + return encryptor.processLastSegment(new byte[0]); + } + + @Test + void testEncryptionMatchesReference() throws Exception { + List referenceCiphertextSegments = + Arrays.asList( + "ffffffff0000000100000000000000000100007f5713b9827bb806318311fcde197146a144c6b485", // pragma: allowlist secret + "ffffffff000000020000000000000000f926dfc0a0bac6263d1634ad9a72f86900872033a271a037", // pragma: allowlist secret + "ffffffff00000003000000000000000080df8fdee872febe574c2b8df0bb34b3fb25bfc5802703a2", // pragma: allowlist secret + "ffffffff000000040000000000000000f4d81083e57451dbfa538827942245019b8bc3354ecc31e0", // pragma: allowlist secret + "ffffffff000000050000000000000000d91b774b5b460bd665910114e155f1cbc55a9a262a54f65e", // pragma: allowlist secret + "ffffffff000000060000000000000000ec723f3807eb71ea42ff03f5420daf34e1a8f4fb58931db1", // pragma: allowlist secret + "ffffffff00000007000000000000000072960c06ec19ce94c27c9fc72d79164f187f37e86325d849", // pragma: allowlist secret + "ffffffff000000080000000000000000c00a40fb140d797da818ab57399cb986bddddd174b8d3d6a", // pragma: allowlist secret + "ffffffff000000090000000000000000065e959cd1ffa521896fb54949a57ad1c1f8291a531c6d60", // pragma: allowlist secret + "ffffffff0000000a0000000000000000dfde3da3f67a081fb31229ac11e43a629ed120fbf9942513" // pragma: allowlist secret + ); + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(0), + 4, + 1L << 40); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + for (String referenceCiphertextSegment : referenceCiphertextSegments) { + byte[] ciphertextBytes = encryptor.processSegment(testData); + String ciphertextHex = toHex(ciphertextBytes); + assertEquals(referenceCiphertextSegment, ciphertextHex); + byte[] plaintextBytes = decryptor.processSegment(ciphertextBytes); + assertArrayEquals(testData, plaintextBytes); + } + + byte[] lastSegment = encryptor.processLastSegment(new byte[0]); + decryptor.processSegment(lastSegment); + } + } + + @Test + void shouldThrowExceptionOnMaxSegmentReached() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, Hash.SHA384, 40, new FloeIvLength(32), new SecureFloeRandom(), 20, 3); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] plaintext = new byte[8]; + encryptor.processSegment(plaintext); + encryptor.processSegment(plaintext); + assertThrows(IllegalStateException.class, () -> encryptor.processSegment(plaintext)); + assertDoesNotThrow(() -> encryptor.processLastSegment(plaintext)); + } + } + + @Test + void shouldThrowExceptionIfPlaintextIsTooShort() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[0])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 0"); + e = assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[7])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 7"); + + close(encryptor); + } + } + + @Test + void shouldThrowEncryptionIfPlaintextIsTooLong() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + IllegalArgumentException e = + assertThrows(IllegalArgumentException.class, () -> encryptor.processSegment(new byte[9])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 9"); + e = + assertThrows( + IllegalArgumentException.class, () -> encryptor.processSegment(new byte[1024])); + assertEquals(e.getMessage(), "segment length mismatch, expected 8, got 1024"); + + e = + assertThrows( + IllegalArgumentException.class, () -> encryptor.processLastSegment(new byte[9])); + assertEquals(e.getMessage(), "last segment is too long, got 9, max is 8"); + + close(encryptor); + } + } + + @ParameterizedTest + @ValueSource(ints = {0, 8}) + void shouldAcceptSegmentWithCorrectSize(int lastSegmentSize) throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + assertDoesNotThrow(() -> encryptor.processSegment(new byte[8])); + assertDoesNotThrow(() -> encryptor.processLastSegment(new byte[lastSegmentSize])); + } + } + + @Test + void shouldNotAcceptNewSegmentsAfterLastOneIsProcessed() throws Exception { + FloeParameterSpec parameterSpec = new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 40, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + assertFalse(encryptor.isClosed()); + encryptor.processLastSegment(new byte[4]); + assertTrue(encryptor.isClosed()); + IllegalStateException e = + assertThrows(IllegalStateException.class, () -> encryptor.processSegment(new byte[4])); + assertEquals("stream has already been closed", e.getMessage()); + e = + assertThrows( + IllegalStateException.class, () -> encryptor.processLastSegment(new byte[4])); + assertEquals("stream has already been closed", e.getMessage()); + } + } + + String toHex(byte[] input) { + StringBuilder result = new StringBuilder(); + for (byte b : input) { + result.append(String.format("%02x", b)); + } + return result.toString(); + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java new file mode 100644 index 000000000..a6914acf5 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeTest.java @@ -0,0 +1,236 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import static com.amazonaws.util.BinaryUtils.toHex; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.charset.StandardCharsets; +import java.security.SecureRandom; +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; + +class FloeTest { + byte[] aad = "This is AAD".getBytes(StandardCharsets.UTF_8); + SecretKey secretKey = new SecretKeySpec(new byte[32], "FLOE"); + + @Nested + class HeaderTests { + @Test + void validateHeaderMatchesForEncryptionAndDecryption() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + decryptor.processSegment(encryptor.processLastSegment(new byte[0])); + } + } + + @Test + void validateHeaderDoesNotMatchInParams() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] header = encryptor.getHeader(); + header[0] = 12; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid parameters header"); + encryptor.processLastSegment(new byte[0]); + } + } + + @Test + void validateHeaderDoesNotMatchInIV() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 4); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] header = encryptor.getHeader(); + header[11]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + encryptor.processLastSegment(new byte[0]); + } + } + + @Test + void validateHeaderDoesNotMatchInHeaderTag() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 4096, 4); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad)) { + byte[] header = encryptor.getHeader(); + header[header.length - 3]++; + IllegalArgumentException e = + assertThrows( + IllegalArgumentException.class, () -> floe.createDecryptor(secretKey, aad, header)); + assertEquals(e.getMessage(), "invalid header tag"); + encryptor.processLastSegment(new byte[0]); + } + } + } + + @Nested + class SegmentTests { + + @Test + void testSegmentEncryptedAndDecrypted() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(678765), + 4, + 1L << 40); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + byte[] ciphertext = encryptor.processLastSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } + } + + @Test + void testSegmentEncryptedAndDecryptedWithRandomData() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(37665), + 4, + 1L << 40); + Floe floe = Floe.getInstance(parameterSpec); + byte[] ciphertext; + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + new SecureRandom().nextBytes(testData); + ciphertext = encryptor.processLastSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } + } + + @Test + void testSegmentEncryptedAndDecryptedWithDerivedKeyRotation() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(6546), + 4, + 1L << 40); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] testData = new byte[8]; + for (int i = 0; i < 10; i++) { + byte[] ciphertext = encryptor.processSegment(testData); + byte[] result = decryptor.processSegment(ciphertext); + assertArrayEquals(testData, result); + } + byte[] ciphertext = encryptor.processLastSegment(testData); + decryptor.processSegment(ciphertext); + } + } + } + + @Nested + class LastSegmentTests { + @Test + void testLastSegmentEncryptedAndDecrypted() throws Exception { + FloeParameterSpec parameterSpec = + new FloeParameterSpec(Aead.AES_GCM_256, Hash.SHA384, 1024, 32); + Floe floe = Floe.getInstance(parameterSpec); + try (FloeEncryptor encryptor = floe.createEncryptor(secretKey, aad); + FloeDecryptor decryptor = floe.createDecryptor(secretKey, aad, encryptor.getHeader())) { + byte[] plaintext = new byte[3]; + byte[] encrypted = encryptor.processLastSegment(plaintext); + byte[] decrypted = decryptor.processSegment(encrypted); + assertArrayEquals(plaintext, decrypted); + } + } + + @Test + void testDecryptLastSegmentWithReferenceDataWithEmptyLastSegment() throws Exception { + FloeParameterSpec floeParameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(0), + 16, + 1L << 40); + Floe floe = Floe.getInstance(floeParameterSpec); + try (FloeEncryptor encryptor = + floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); + FloeDecryptor decryptor = + floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader())) { + byte[] plaintext = new byte[8]; + byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); + byte[] encryptedLastSegment = encryptor.processLastSegment(new byte[0]); + + assertEquals( + toHex(encryptedFirstSegment), + "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret + assertEquals( + toHex(encryptedLastSegment), + "000000200000000200000000000000004a4082e6b94a8b1b2053f40879402df1"); // pragma: + // allowlist secret + + assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); + assertArrayEquals(new byte[0], decryptor.processSegment(encryptedLastSegment)); + } + } + + @Test + void testDecryptLastSegmentWithReferenceDataWithNonEmptyLastSegment() throws Exception { + FloeParameterSpec floeParameterSpec = + new FloeParameterSpec( + Aead.AES_GCM_256, + Hash.SHA384, + 40, + new FloeIvLength(32), + new IncrementingFloeRandom(0), + 16, + 1L << 40); + Floe floe = Floe.getInstance(floeParameterSpec); + try (FloeEncryptor encryptor = + floe.createEncryptor(new SecretKeySpec(new byte[16], "FLOE"), new byte[0]); + FloeDecryptor decryptor = + floe.createDecryptor(secretKey, new byte[0], encryptor.getHeader())) { + byte[] plaintext = new byte[8]; + byte[] encryptedFirstSegment = encryptor.processSegment(plaintext); + byte[] encryptedLastSegment = encryptor.processLastSegment(plaintext); + + assertEquals( + toHex(encryptedFirstSegment), + "ffffffff0000000100000000000000002dd631464f6a583369b74f546adfa4db9a838732d6338ef4"); // pragma: allowlist secret + assertEquals( + toHex(encryptedLastSegment), + "000000280000000200000000000000003b14259ad693c7df7a2d6b9d9912dc70a81205d41ac43a41"); // pragma: allowlist secret + + assertArrayEquals(plaintext, decryptor.processSegment(encryptedFirstSegment)); + assertArrayEquals(plaintext, decryptor.processSegment(encryptedLastSegment)); + } + } + } +} diff --git a/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java new file mode 100644 index 000000000..fb5a152d1 --- /dev/null +++ b/src/test/java/net/snowflake/client/jdbc/cloud/storage/floe/IncrementingFloeRandom.java @@ -0,0 +1,18 @@ +package net.snowflake.client.jdbc.cloud.storage.floe; + +import java.nio.ByteBuffer; + +public class IncrementingFloeRandom implements FloeRandom { + private int seed; + + public IncrementingFloeRandom(int seed) { + this.seed = seed; + } + + @Override + public byte[] ofLength(int length) { + ByteBuffer buffer = ByteBuffer.allocate(length); + buffer.putInt(seed++); + return buffer.array(); + } +}