Skip to content

Commit 9a25624

Browse files
author
Paras Jain
committed
adds ability to enable/disable custom serialization
Signed-off-by: Paras Jain <[email protected]>
1 parent 87de7e2 commit 9a25624

File tree

7 files changed

+97
-31
lines changed

7 files changed

+97
-31
lines changed

src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
import io.netty.handler.ssl.SslHandler;
4848

49+
import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization;
50+
4951
public class SecuritySSLRequestHandler<T extends TransportRequest> implements TransportRequestHandler<T> {
5052

5153
private final String action;
@@ -96,7 +98,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task
9698

9799
threadContext.putTransient(
98100
ConfigConstants.USE_JDK_SERIALIZATION,
99-
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
101+
shouldUseJDKSerialization(channel.getVersion())
100102
);
101103

102104
if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) {

src/main/java/org/opensearch/security/support/Base64Helper.java

+32-2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626

2727
package org.opensearch.security.support;
2828

29+
import org.opensearch.Version;
30+
2931
import java.io.Serializable;
3032

3133
public class Base64Helper {
@@ -35,11 +37,11 @@ public static String serializeObject(final Serializable object, final boolean us
3537
}
3638

3739
public static String serializeObject(final Serializable object) {
38-
return serializeObject(object, false);
40+
return serializeObject(object, true);
3941
}
4042

4143
public static Serializable deserializeObject(final String string) {
42-
return deserializeObject(string, false);
44+
return deserializeObject(string, true);
4345
}
4446

4547
public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
@@ -69,4 +71,32 @@ public static String ensureJDKSerialized(final String string) {
6971
// If we see an exception now, we want the caller to see it -
7072
return Base64Helper.serializeObject(serializable, true);
7173
}
74+
75+
/**
76+
* Ensures that the returned string is custom serialized.
77+
*
78+
* If the supplied string is a JDK serialized representation, will deserialize it and further serialize using
79+
* custom, otherwise returns the string as is.
80+
*
81+
* @param string original string, can be JDK or custom serialized
82+
* @return custom serialized string
83+
*/
84+
public static String ensureCustomSerialized(final String string) {
85+
Serializable serializable;
86+
try {
87+
serializable = Base64Helper.deserializeObject(string, true);
88+
} catch (Exception e) {
89+
// We received an exception when de-serializing the given string. It is probably custom serialized.
90+
// Try to deserialize using custom
91+
Base64Helper.deserializeObject(string, false);
92+
// Since we could deserialize the object using custom, the string is already custom serialized, return as is
93+
return string;
94+
}
95+
// If we see an exception now, we want the caller to see it -
96+
return Base64Helper.serializeObject(serializable, false);
97+
}
98+
99+
public static boolean shouldUseJDKSerialization(Version remoteVersion) {
100+
return ! remoteVersion.equals(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
101+
}
72102
}

src/main/java/org/opensearch/security/support/ConfigConstants.java

+1
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,7 @@ public enum RolesMappingResolution {
334334
public static final boolean EXTENSIONS_BWC_PLUGIN_MODE_DEFAULT = false;
335335
// CS-ENFORCE-SINGLE
336336

337+
337338
public static Set<String> getSettingAsSet(
338339
final Settings settings,
339340
final String key,

src/main/java/org/opensearch/security/transport/SecurityInterceptor.java

+6-5
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
import org.opensearch.transport.TransportResponseHandler;
7373

7474
import static org.opensearch.security.OpenSearchSecurityPlugin.isActionTraceEnabled;
75+
import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization;
7576

7677
public class SecurityInterceptor {
7778

@@ -148,7 +149,7 @@ public <T extends TransportResponse> void sendRequestDecorate(
148149
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);
149150

150151
final boolean isDebugEnabled = log.isDebugEnabled();
151-
final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
152+
final boolean useJDKSerialization = shouldUseJDKSerialization(connection.getVersion());
152153
final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());
153154

154155
try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
@@ -226,13 +227,13 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
226227
);
227228
}
228229

229-
if (useJDKSerialization) {
230-
Map<String, String> jdkSerializedHeaders = new HashMap<>();
230+
if (!useJDKSerialization) {
231+
Map<String, String> customSerializedHeaders = new HashMap<>();
231232
HeaderHelper.getAllSerializedHeaderNames()
232233
.stream()
233234
.filter(k -> headerMap.get(k) != null)
234-
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
235-
headerMap.putAll(jdkSerializedHeaders);
235+
.forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k))));
236+
headerMap.putAll(customSerializedHeaders);
236237
}
237238

238239
getThreadContext().putHeader(headerMap);

src/test/java/org/opensearch/security/support/Base64HelperTest.java

+10
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,14 @@ public void testEnsureJDKSerialized() {
4848
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized));
4949
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized));
5050
}
51+
52+
53+
@Test
54+
public void testEnsureCustomSerialized() {
55+
String test = "string";
56+
String jdkSerialized = Base64Helper.serializeObject(test, true);
57+
String customSerialized = Base64Helper.serializeObject(test, false);
58+
Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(jdkSerialized));
59+
Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(customSerialized));
60+
}
5161
}

src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java

+38-17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.opensearch.cluster.node.DiscoveryNode;
2222
import org.opensearch.cluster.service.ClusterService;
2323
import org.opensearch.common.settings.Settings;
24+
import org.opensearch.common.util.concurrent.ThreadContext;
2425
import org.opensearch.core.common.transport.TransportAddress;
2526
import org.opensearch.core.transport.TransportResponse;
2627
import org.opensearch.extensions.ExtensionsManager;
@@ -108,8 +109,7 @@ public void setup() {
108109
);
109110
}
110111

111-
private void testSendRequestDecorate(Version remoteNodeVersion) {
112-
boolean useJDKSerialization = remoteNodeVersion.before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
112+
private void testSendRequestDecorate(DiscoveryNode localNode, DiscoveryNode otherNode, boolean shouldUseJDKSerialization) {
113113
ClusterName clusterName = ClusterName.DEFAULT;
114114
when(clusterService.getClusterName()).thenReturn(clusterName);
115115

@@ -143,17 +143,7 @@ private void testSendRequestDecorate(Version remoteNodeVersion) {
143143
@SuppressWarnings("unchecked")
144144
TransportResponseHandler<TransportResponse> handler = mock(TransportResponseHandler.class);
145145

146-
InetAddress localAddress = null;
147-
try {
148-
localAddress = InetAddress.getByName("0.0.0.0");
149-
} catch (final UnknownHostException uhe) {
150-
throw new RuntimeException(uhe);
151-
}
152-
153-
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 1234), Version.CURRENT);
154146
Connection connection1 = transportService.getConnection(localNode);
155-
156-
DiscoveryNode otherNode = new DiscoveryNode("remote-node", new TransportAddress(localAddress, 4321), remoteNodeVersion);
157147
Connection connection2 = transportService.getConnection(otherNode);
158148

159149
// from thread context inside sendRequestDecorate
@@ -189,7 +179,7 @@ public <T extends TransportResponse> void sendRequest(
189179
TransportResponseHandler<T> handler
190180
) {
191181
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
192-
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, useJDKSerialization));
182+
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, shouldUseJDKSerialization));
193183
}
194184
};
195185
// isSameNodeRequest = false
@@ -201,17 +191,48 @@ public <T extends TransportResponse> void sendRequest(
201191
assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null);
202192
}
203193

194+
195+
/**
196+
* Tests the scenario when remote node is on same OS version
197+
*/
204198
@Test
205199
public void testSendRequestDecorate() {
206-
testSendRequestDecorate(Version.CURRENT);
200+
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
201+
DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress() ,3456), Version.CURRENT);
202+
testSendRequestDecorate(localNode, otherNode, true);
207203
}
208204

209205
/**
210-
* Tests the scenario when remote node does not implement custom serialization protocol and uses JDK serialization
206+
* Tests the scenarios for mixed node versions
211207
*/
212208
@Test
213-
public void testSendRequestDecorateWhenRemoteNodeUsesJDKSerde() {
214-
testSendRequestDecorate(Version.V_2_0_0);
209+
public void testSendRequestDecorateWithMixedNodeVersions() {
210+
211+
// local on latest version, remote on 2.11.0 - should use custom
212+
213+
try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
214+
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
215+
DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
216+
testSendRequestDecorate(localNode, otherNode, false);
217+
}
218+
219+
// remote node is on a version > 2.11.1 while local node is on version 2.11.1 - should use JDK
220+
try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
221+
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
222+
DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), Version.V_2_11_1);
223+
testSendRequestDecorate(localNode, otherNode, true);
224+
}
225+
215226
}
216227

228+
229+
private static InetAddress getLocalAddress() {
230+
try {
231+
return InetAddress.getByName("0.0.0.0");
232+
} catch (final UnknownHostException uhe) {
233+
throw new RuntimeException(uhe);
234+
}
235+
}
236+
237+
217238
}

src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java

+7-6
Original file line numberDiff line numberDiff line change
@@ -89,14 +89,15 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
8989
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
9090

9191
threadPool.getThreadContext().stashContext();
92-
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
92+
when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
9393
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
9494
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
9595

9696
threadPool.getThreadContext().stashContext();
97-
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
97+
when(transportChannel.getVersion()).thenReturn(Version.CURRENT);
9898
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
99-
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
99+
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
100+
100101
}
101102

102103
@Test
@@ -113,14 +114,14 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except
113114
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
114115

115116
threadPool.getThreadContext().stashContext();
116-
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
117+
when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
117118
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
118119
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
119120

120121
threadPool.getThreadContext().stashContext();
121-
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
122+
when(transportChannel.getVersion()).thenReturn(Version.CURRENT);
122123
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
123-
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
124+
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
124125
}
125126

126127
@Test

0 commit comments

Comments
 (0)