Skip to content

Commit 52d061f

Browse files
committed
Implement processing segments
1 parent 2634bdd commit 52d061f

22 files changed

+515
-133
lines changed

src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java

+21-20
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,18 @@
33
*/
44
package net.snowflake.client.jdbc.cloud.storage;
55

6-
import static java.nio.file.StandardOpenOption.CREATE;
7-
import static java.nio.file.StandardOpenOption.READ;
6+
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
7+
import net.snowflake.client.jdbc.MatDesc;
8+
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;
89

10+
import javax.crypto.BadPaddingException;
11+
import javax.crypto.Cipher;
12+
import javax.crypto.CipherInputStream;
13+
import javax.crypto.IllegalBlockSizeException;
14+
import javax.crypto.NoSuchPaddingException;
15+
import javax.crypto.SecretKey;
16+
import javax.crypto.spec.GCMParameterSpec;
17+
import javax.crypto.spec.SecretKeySpec;
918
import java.io.File;
1019
import java.io.FileOutputStream;
1120
import java.io.IOException;
@@ -18,23 +27,15 @@
1827
import java.security.NoSuchAlgorithmException;
1928
import java.security.SecureRandom;
2029
import java.util.Base64;
21-
import javax.crypto.BadPaddingException;
22-
import javax.crypto.Cipher;
23-
import javax.crypto.CipherInputStream;
24-
import javax.crypto.IllegalBlockSizeException;
25-
import javax.crypto.NoSuchPaddingException;
26-
import javax.crypto.SecretKey;
27-
import javax.crypto.spec.GCMParameterSpec;
28-
import javax.crypto.spec.SecretKeySpec;
29-
import net.snowflake.client.jdbc.MatDesc;
30-
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;
3130

32-
class GcmEncryptionProvider {
31+
import static java.nio.file.StandardOpenOption.CREATE;
32+
import static java.nio.file.StandardOpenOption.READ;
33+
34+
@SnowflakeJdbcInternalApi
35+
public class GcmEncryptionProvider {
3336
private static final int TAG_LENGTH_IN_BITS = 128;
3437
private static final int IV_LENGTH_IN_BYTES = 12;
3538
private static final String AES = "AES";
36-
private static final String FILE_CIPHER = "AES/GCM/NoPadding";
37-
private static final String KEY_CIPHER = "AES/GCM/NoPadding";
3839
private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB
3940
private static final ThreadLocal<SecureRandom> random =
4041
ThreadLocal.withInitial(SecureRandom::new);
@@ -85,7 +86,7 @@ private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvD
8586
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
8687
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
8788
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData);
88-
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
89+
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
8990
keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec);
9091
if (aad != null) {
9192
keyCipher.updateAAD(aad);
@@ -99,7 +100,7 @@ private static CipherInputStream encryptContent(
99100
NoSuchAlgorithmException {
100101
SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES);
101102
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes);
102-
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
103+
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
103104
fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec);
104105
if (aad != null) {
105106
fileCipher.updateAAD(aad);
@@ -172,7 +173,7 @@ private static CipherInputStream decryptContentFromStream(
172173
NoSuchAlgorithmException {
173174
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
174175
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
175-
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
176+
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
176177
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
177178
if (aad != null) {
178179
fileCipher.updateAAD(aad);
@@ -187,7 +188,7 @@ private static void decryptContentFromFile(
187188
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
188189
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes);
189190
byte[] buffer = new byte[BUFFER_SIZE];
190-
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
191+
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
191192
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
192193
if (aad != null) {
193194
fileCipher.updateAAD(aad);
@@ -215,7 +216,7 @@ private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyByte
215216
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
216217
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
217218
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
218-
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
219+
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
219220
keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec);
220221
if (aad != null) {
221222
keyCipher.updateAAD(aad);
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,49 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

3+
import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm;
4+
35
public enum Aead {
4-
AES_GCM_128((byte) 0),
5-
AES_GCM_256((byte) 1);
6+
// TODO confirm id
7+
AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new Gcm(16)),
8+
AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new Gcm(16));
69

710
private byte id;
11+
private String jceName;
12+
private int keyLength;
13+
private int ivLength;
14+
private int authTagLength;
15+
private AeadProvider aeadProvider;
816

9-
Aead(byte id) {
17+
Aead(byte id, String jceName, int keyLength, int ivLength, int authTagLength, AeadProvider aeadProvider) {
18+
this.jceName = jceName;
19+
this.keyLength = keyLength;
1020
this.id = id;
21+
this.ivLength = ivLength;
22+
this.authTagLength = authTagLength;
23+
this.aeadProvider = aeadProvider;
1124
}
1225

1326
byte getId() {
1427
return id;
1528
}
29+
30+
String getJceName() {
31+
return jceName;
32+
}
33+
34+
int getKeyLength() {
35+
return keyLength;
36+
}
37+
38+
int getIvLength() {
39+
return ivLength;
40+
}
41+
42+
int getAuthTagLength() {
43+
return authTagLength;
44+
}
45+
46+
AeadProvider getAeadProvider() {
47+
return aeadProvider;
48+
}
1649
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import java.nio.ByteBuffer;
4+
5+
class AeadAad {
6+
private final byte[] bytes;
7+
8+
private AeadAad(long segmentCounter, byte terminalityByte) {
9+
ByteBuffer buf = ByteBuffer.allocate(9);
10+
buf.putLong(segmentCounter);
11+
buf.put(terminalityByte);
12+
this.bytes = buf.array();
13+
}
14+
15+
static AeadAad nonTerminal(long segmentCounter) {
16+
return new AeadAad(segmentCounter, (byte) 0);
17+
}
18+
19+
byte[] getBytes() {
20+
return bytes;
21+
}
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import java.nio.ByteBuffer;
4+
5+
class AeadIv {
6+
private final byte[] bytes;
7+
8+
AeadIv(byte[] bytes) {
9+
this.bytes = bytes;
10+
}
11+
12+
public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) {
13+
return new AeadIv(floeRandom.ofLength(ivLength));
14+
}
15+
16+
public static AeadIv from(ByteBuffer buffer, int ivLength) {
17+
byte[] bytes = new byte[ivLength];
18+
buffer.get(bytes);
19+
return new AeadIv(bytes);
20+
}
21+
22+
byte[] getBytes() {
23+
return bytes;
24+
}
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import javax.crypto.SecretKey;
4+
5+
class AeadKey {
6+
private final SecretKey key;
7+
8+
AeadKey(SecretKey key) {
9+
this.key = key;
10+
}
11+
12+
SecretKey getKey() {
13+
return key;
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import javax.crypto.SecretKey;
4+
import java.security.GeneralSecurityException;
5+
6+
public interface AeadProvider {
7+
byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext) throws GeneralSecurityException;
8+
byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext) throws GeneralSecurityException;
9+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import javax.crypto.SecretKey;
4+
import javax.crypto.spec.SecretKeySpec;
5+
6+
abstract class BaseSegmentProcessor {
7+
protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1;
8+
protected static final int headerTagLength = 32;
9+
10+
protected final FloeParameterSpec parameterSpec;
11+
protected final FloeKey floeKey;
12+
protected final FloeAad floeAad;
13+
14+
protected final KeyDerivator keyDerivator;
15+
16+
private AeadKey currentAeadKey;
17+
18+
BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) {
19+
this.parameterSpec = parameterSpec;
20+
this.floeKey = floeKey;
21+
this.floeAad = floeAad;
22+
this.keyDerivator = new KeyDerivator(parameterSpec);
23+
}
24+
25+
protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
26+
if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) {
27+
currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter);
28+
}
29+
return currentAeadKey;
30+
}
31+
32+
private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
33+
byte[] keyBytes = keyDerivator.hkdfExpand(floeKey, floeIv, floeAad, new DekTagFloePurpose(segmentCounter), parameterSpec.getAead().getKeyLength());
34+
SecretKey key = new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD
35+
return new AeadKey(key);
36+
}
37+
}

src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java

-18
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

3-
public interface FloeDecryptor {}
3+
public interface FloeDecryptor extends SegmentProcessor {
4+
5+
}
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

33
import java.nio.ByteBuffer;
4+
import java.security.GeneralSecurityException;
45
import java.util.Arrays;
56

6-
public class FloeDecryptorImpl extends FloeBase implements FloeDecryptor {
7+
public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor {
8+
private final FloeIv floeIv;
9+
private long segmentCounter;
10+
711
FloeDecryptorImpl(
812
FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) {
913
super(parameterSpec, floeKey, floeAad);
10-
validate(floeHeaderAsBytes);
11-
}
12-
13-
public void validate(byte[] floeHeaderAsBytes) {
14-
byte[] encodedParams = parameterSpec.paramEncode();
14+
byte[] encodedParams = this.parameterSpec.paramEncode();
1515
if (floeHeaderAsBytes.length
16-
!= encodedParams.length + parameterSpec.getFloeIvLength().getLength() + headerTagLength) {
16+
!= encodedParams.length + this.parameterSpec.getFloeIvLength().getLength() + headerTagLength) {
1717
throw new IllegalArgumentException("invalid header length");
1818
}
1919
ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes);
@@ -24,17 +24,48 @@ public void validate(byte[] floeHeaderAsBytes) {
2424
throw new IllegalArgumentException("invalid parameters header");
2525
}
2626

27-
byte[] floeIvBytes = new byte[parameterSpec.getFloeIvLength().getLength()];
27+
byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()];
2828
floeHeader.get(floeIvBytes, 0, floeIvBytes.length);
29-
FloeIv floeIv = new FloeIv(floeIvBytes);
29+
this.floeIv = new FloeIv(floeIvBytes);
3030

3131
byte[] headerTagFromHeader = new byte[headerTagLength];
3232
floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length);
3333

3434
byte[] headerTag =
35-
floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength);
35+
keyDerivator.hkdfExpand(this.floeKey, floeIv, this.floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength);
3636
if (!Arrays.equals(headerTag, headerTagFromHeader)) {
3737
throw new IllegalArgumentException("invalid header tag");
3838
}
3939
}
40+
41+
@Override
42+
public byte[] processSegment(byte[] input) {
43+
try {
44+
verifySegmentLength(input);
45+
ByteBuffer inputBuf = ByteBuffer.wrap(input);
46+
verifySegmentSizeMarker(inputBuf);
47+
AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter);
48+
AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength());
49+
AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++);
50+
AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider();
51+
byte[] ciphertext = new byte[inputBuf.remaining()];
52+
inputBuf.get(ciphertext);
53+
return aeadProvider.decrypt(aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext);
54+
} catch (GeneralSecurityException e) {
55+
throw new RuntimeException(e);
56+
}
57+
}
58+
59+
private void verifySegmentLength(byte[] input) {
60+
if (input.length != parameterSpec.getEncryptedSegmentLength()) {
61+
throw new IllegalArgumentException(String.format("segment length mismatch, expected %d, got %d", parameterSpec.getEncryptedSegmentLength(), input.length));
62+
}
63+
}
64+
65+
private void verifySegmentSizeMarker(ByteBuffer inputBuf) {
66+
int segmentSizeMarker = inputBuf.getInt();
67+
if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) {
68+
throw new IllegalStateException(String.format("segment length marker mismatch, expected: %d, got :%d", NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker));
69+
}
70+
}
4071
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

3-
public interface FloeEncryptor {
3+
public interface FloeEncryptor extends SegmentProcessor {
44
byte[] getHeader();
55
}

0 commit comments

Comments
 (0)