Skip to content

Commit 68b11d2

Browse files
committed
Implement processing segments
1 parent 2634bdd commit 68b11d2

22 files changed

+564
-115
lines changed

src/main/java/net/snowflake/client/jdbc/cloud/storage/GcmEncryptionProvider.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@
2626
import javax.crypto.SecretKey;
2727
import javax.crypto.spec.GCMParameterSpec;
2828
import javax.crypto.spec.SecretKeySpec;
29+
import net.snowflake.client.core.SnowflakeJdbcInternalApi;
2930
import net.snowflake.client.jdbc.MatDesc;
3031
import net.snowflake.common.core.RemoteStoreFileEncryptionMaterial;
3132

32-
class GcmEncryptionProvider {
33+
@SnowflakeJdbcInternalApi
34+
public class GcmEncryptionProvider {
3335
private static final int TAG_LENGTH_IN_BITS = 128;
3436
private static final int IV_LENGTH_IN_BYTES = 12;
3537
private static final String AES = "AES";
36-
private static final String FILE_CIPHER = "AES/GCM/NoPadding";
37-
private static final String KEY_CIPHER = "AES/GCM/NoPadding";
3838
private static final int BUFFER_SIZE = 8 * 1024 * 1024; // 2 MB
3939
private static final ThreadLocal<SecureRandom> random =
4040
ThreadLocal.withInitial(SecureRandom::new);
@@ -85,7 +85,7 @@ private static byte[] encryptKey(byte[] kekBytes, byte[] keyBytes, byte[] keyIvD
8585
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
8686
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
8787
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, keyIvData);
88-
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
88+
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
8989
keyCipher.init(Cipher.ENCRYPT_MODE, kek, gcmParameterSpec);
9090
if (aad != null) {
9191
keyCipher.updateAAD(aad);
@@ -99,7 +99,7 @@ private static CipherInputStream encryptContent(
9999
NoSuchAlgorithmException {
100100
SecretKey fileKey = new SecretKeySpec(keyBytes, 0, keyBytes.length, AES);
101101
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, dataIvBytes);
102-
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
102+
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
103103
fileCipher.init(Cipher.ENCRYPT_MODE, fileKey, gcmParameterSpec);
104104
if (aad != null) {
105105
fileCipher.updateAAD(aad);
@@ -172,7 +172,7 @@ private static CipherInputStream decryptContentFromStream(
172172
NoSuchAlgorithmException {
173173
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
174174
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
175-
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
175+
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
176176
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
177177
if (aad != null) {
178178
fileCipher.updateAAD(aad);
@@ -187,7 +187,7 @@ private static void decryptContentFromFile(
187187
SecretKey fileKey = new SecretKeySpec(fileKeyBytes, AES);
188188
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, cekIvBytes);
189189
byte[] buffer = new byte[BUFFER_SIZE];
190-
Cipher fileCipher = Cipher.getInstance(FILE_CIPHER);
190+
Cipher fileCipher = Cipher.getInstance(JCE_CIPHER_NAME);
191191
fileCipher.init(Cipher.DECRYPT_MODE, fileKey, gcmParameterSpec);
192192
if (aad != null) {
193193
fileCipher.updateAAD(aad);
@@ -215,7 +215,7 @@ private static byte[] decryptKey(byte[] kekBytes, byte[] ivBytes, byte[] keyByte
215215
BadPaddingException, NoSuchPaddingException, NoSuchAlgorithmException {
216216
SecretKey kek = new SecretKeySpec(kekBytes, 0, kekBytes.length, AES);
217217
GCMParameterSpec gcmParameterSpec = new GCMParameterSpec(TAG_LENGTH_IN_BITS, ivBytes);
218-
Cipher keyCipher = Cipher.getInstance(KEY_CIPHER);
218+
Cipher keyCipher = Cipher.getInstance(JCE_CIPHER_NAME);
219219
keyCipher.init(Cipher.DECRYPT_MODE, kek, gcmParameterSpec);
220220
if (aad != null) {
221221
keyCipher.updateAAD(aad);
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,55 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

3+
import net.snowflake.client.jdbc.cloud.storage.floe.aead.Gcm;
4+
35
public enum Aead {
4-
AES_GCM_128((byte) 0),
5-
AES_GCM_256((byte) 1);
6+
// TODO confirm id
7+
AES_GCM_256((byte) 0, "AES/GCM/NoPadding", 32, 12, 16, new Gcm(16)),
8+
AES_GCM_128((byte) 1, "AES/GCM/NoPadding", 16, 12, 16, new Gcm(16));
69

710
private byte id;
11+
private String jceName;
12+
private int keyLength;
13+
private int ivLength;
14+
private int authTagLength;
15+
private AeadProvider aeadProvider;
816

9-
Aead(byte id) {
17+
Aead(
18+
byte id,
19+
String jceName,
20+
int keyLength,
21+
int ivLength,
22+
int authTagLength,
23+
AeadProvider aeadProvider) {
24+
this.jceName = jceName;
25+
this.keyLength = keyLength;
1026
this.id = id;
27+
this.ivLength = ivLength;
28+
this.authTagLength = authTagLength;
29+
this.aeadProvider = aeadProvider;
1130
}
1231

1332
byte getId() {
1433
return id;
1534
}
35+
36+
String getJceName() {
37+
return jceName;
38+
}
39+
40+
int getKeyLength() {
41+
return keyLength;
42+
}
43+
44+
int getIvLength() {
45+
return ivLength;
46+
}
47+
48+
int getAuthTagLength() {
49+
return authTagLength;
50+
}
51+
52+
AeadProvider getAeadProvider() {
53+
return aeadProvider;
54+
}
1655
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import java.nio.ByteBuffer;
4+
5+
class AeadAad {
6+
private final byte[] bytes;
7+
8+
private AeadAad(long segmentCounter, byte terminalityByte) {
9+
ByteBuffer buf = ByteBuffer.allocate(9);
10+
buf.putLong(segmentCounter);
11+
buf.put(terminalityByte);
12+
this.bytes = buf.array();
13+
}
14+
15+
static AeadAad nonTerminal(long segmentCounter) {
16+
return new AeadAad(segmentCounter, (byte) 0);
17+
}
18+
19+
byte[] getBytes() {
20+
return bytes;
21+
}
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import java.nio.ByteBuffer;
4+
5+
class AeadIv {
6+
private final byte[] bytes;
7+
8+
AeadIv(byte[] bytes) {
9+
this.bytes = bytes;
10+
}
11+
12+
public static AeadIv generateRandom(FloeRandom floeRandom, int ivLength) {
13+
return new AeadIv(floeRandom.ofLength(ivLength));
14+
}
15+
16+
public static AeadIv from(ByteBuffer buffer, int ivLength) {
17+
byte[] bytes = new byte[ivLength];
18+
buffer.get(bytes);
19+
return new AeadIv(bytes);
20+
}
21+
22+
byte[] getBytes() {
23+
return bytes;
24+
}
25+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import javax.crypto.SecretKey;
4+
5+
class AeadKey {
6+
private final SecretKey key;
7+
8+
AeadKey(SecretKey key) {
9+
this.key = key;
10+
}
11+
12+
SecretKey getKey() {
13+
return key;
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import java.security.GeneralSecurityException;
4+
import javax.crypto.SecretKey;
5+
6+
public interface AeadProvider {
7+
byte[] encrypt(SecretKey key, byte[] iv, byte[] aad, byte[] plaintext)
8+
throws GeneralSecurityException;
9+
10+
byte[] decrypt(SecretKey key, byte[] iv, byte[] aad, byte[] ciphertext)
11+
throws GeneralSecurityException;
12+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package net.snowflake.client.jdbc.cloud.storage.floe;
2+
3+
import javax.crypto.SecretKey;
4+
import javax.crypto.spec.SecretKeySpec;
5+
6+
abstract class BaseSegmentProcessor {
7+
protected static final int NON_TERMINAL_SEGMENT_SIZE_MARKER = -1;
8+
protected static final int headerTagLength = 32;
9+
10+
protected final FloeParameterSpec parameterSpec;
11+
protected final FloeKey floeKey;
12+
protected final FloeAad floeAad;
13+
14+
protected final KeyDerivator keyDerivator;
15+
16+
private AeadKey currentAeadKey;
17+
18+
BaseSegmentProcessor(FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad) {
19+
this.parameterSpec = parameterSpec;
20+
this.floeKey = floeKey;
21+
this.floeAad = floeAad;
22+
this.keyDerivator = new KeyDerivator(parameterSpec);
23+
}
24+
25+
protected AeadKey getKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
26+
if (currentAeadKey == null || segmentCounter % parameterSpec.getKeyRotationModulo() == 0) {
27+
currentAeadKey = deriveKey(floeKey, floeIv, floeAad, segmentCounter);
28+
}
29+
return currentAeadKey;
30+
}
31+
32+
private AeadKey deriveKey(FloeKey floeKey, FloeIv floeIv, FloeAad floeAad, long segmentCounter) {
33+
byte[] keyBytes =
34+
keyDerivator.hkdfExpand(
35+
floeKey,
36+
floeIv,
37+
floeAad,
38+
new DekTagFloePurpose(segmentCounter),
39+
parameterSpec.getAead().getKeyLength());
40+
SecretKey key =
41+
new SecretKeySpec(keyBytes, "AES"); // for now it is safe as we use only AES as AEAD
42+
return new AeadKey(key);
43+
}
44+
}

src/main/java/net/snowflake/client/jdbc/cloud/storage/floe/FloeBase.java

-18
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

3-
public interface FloeDecryptor {}
3+
public interface FloeDecryptor extends SegmentProcessor {}
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,21 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

33
import java.nio.ByteBuffer;
4+
import java.security.GeneralSecurityException;
45
import java.util.Arrays;
56

6-
public class FloeDecryptorImpl extends FloeBase implements FloeDecryptor {
7+
public class FloeDecryptorImpl extends BaseSegmentProcessor implements FloeDecryptor {
8+
private final FloeIv floeIv;
9+
private long segmentCounter;
10+
711
FloeDecryptorImpl(
812
FloeParameterSpec parameterSpec, FloeKey floeKey, FloeAad floeAad, byte[] floeHeaderAsBytes) {
913
super(parameterSpec, floeKey, floeAad);
10-
validate(floeHeaderAsBytes);
11-
}
12-
13-
public void validate(byte[] floeHeaderAsBytes) {
14-
byte[] encodedParams = parameterSpec.paramEncode();
14+
byte[] encodedParams = this.parameterSpec.paramEncode();
1515
if (floeHeaderAsBytes.length
16-
!= encodedParams.length + parameterSpec.getFloeIvLength().getLength() + headerTagLength) {
16+
!= encodedParams.length
17+
+ this.parameterSpec.getFloeIvLength().getLength()
18+
+ headerTagLength) {
1719
throw new IllegalArgumentException("invalid header length");
1820
}
1921
ByteBuffer floeHeader = ByteBuffer.wrap(floeHeaderAsBytes);
@@ -24,17 +26,56 @@ public void validate(byte[] floeHeaderAsBytes) {
2426
throw new IllegalArgumentException("invalid parameters header");
2527
}
2628

27-
byte[] floeIvBytes = new byte[parameterSpec.getFloeIvLength().getLength()];
29+
byte[] floeIvBytes = new byte[this.parameterSpec.getFloeIvLength().getLength()];
2830
floeHeader.get(floeIvBytes, 0, floeIvBytes.length);
29-
FloeIv floeIv = new FloeIv(floeIvBytes);
31+
this.floeIv = new FloeIv(floeIvBytes);
3032

3133
byte[] headerTagFromHeader = new byte[headerTagLength];
3234
floeHeader.get(headerTagFromHeader, 0, headerTagFromHeader.length);
3335

3436
byte[] headerTag =
35-
floeKdf.hkdfExpand(floeKey, floeIv, floeAad, FloePurpose.HEADER_TAG, headerTagLength);
37+
keyDerivator.hkdfExpand(
38+
this.floeKey, floeIv, this.floeAad, HeaderTagFloePurpose.INSTANCE, headerTagLength);
3639
if (!Arrays.equals(headerTag, headerTagFromHeader)) {
3740
throw new IllegalArgumentException("invalid header tag");
3841
}
3942
}
43+
44+
@Override
45+
public byte[] processSegment(byte[] input) {
46+
try {
47+
verifySegmentLength(input);
48+
ByteBuffer inputBuf = ByteBuffer.wrap(input);
49+
verifySegmentSizeMarker(inputBuf);
50+
AeadKey aeadKey = getKey(floeKey, floeIv, floeAad, segmentCounter);
51+
AeadIv aeadIv = AeadIv.from(inputBuf, parameterSpec.getAead().getIvLength());
52+
AeadAad aeadAad = AeadAad.nonTerminal(segmentCounter++);
53+
AeadProvider aeadProvider = parameterSpec.getAead().getAeadProvider();
54+
byte[] ciphertext = new byte[inputBuf.remaining()];
55+
inputBuf.get(ciphertext);
56+
return aeadProvider.decrypt(
57+
aeadKey.getKey(), aeadIv.getBytes(), aeadAad.getBytes(), ciphertext);
58+
} catch (GeneralSecurityException e) {
59+
throw new RuntimeException(e);
60+
}
61+
}
62+
63+
private void verifySegmentLength(byte[] input) {
64+
if (input.length != parameterSpec.getEncryptedSegmentLength()) {
65+
throw new IllegalArgumentException(
66+
String.format(
67+
"segment length mismatch, expected %d, got %d",
68+
parameterSpec.getEncryptedSegmentLength(), input.length));
69+
}
70+
}
71+
72+
private void verifySegmentSizeMarker(ByteBuffer inputBuf) {
73+
int segmentSizeMarker = inputBuf.getInt();
74+
if (segmentSizeMarker != NON_TERMINAL_SEGMENT_SIZE_MARKER) {
75+
throw new IllegalStateException(
76+
String.format(
77+
"segment length marker mismatch, expected: %d, got :%d",
78+
NON_TERMINAL_SEGMENT_SIZE_MARKER, segmentSizeMarker));
79+
}
80+
}
4081
}
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
package net.snowflake.client.jdbc.cloud.storage.floe;
22

3-
public interface FloeEncryptor {
3+
public interface FloeEncryptor extends SegmentProcessor {
44
byte[] getHeader();
55
}

0 commit comments

Comments
 (0)