Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable custom serialization #3789

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ private <Request extends ActionRequest, Response extends ActionResponse> void ap
}

if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) {
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false);
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, true);
}

final ComplianceConfig complianceConfig = auditLog.getComplianceConfig();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

import io.netty.handler.ssl.SslHandler;

import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization;

public class SecuritySSLRequestHandler<T extends TransportRequest> implements TransportRequestHandler<T> {

private final String action;
Expand Down Expand Up @@ -94,10 +96,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task
channel = getInnerChannel(channel);
}

threadContext.putTransient(
ConfigConstants.USE_JDK_SERIALIZATION,
channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION)
);
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, shouldUseJDKSerialization(channel.getVersion()));

if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) {
final Exception exception = ExceptionUtils.createBadHeaderException();
Expand Down
34 changes: 32 additions & 2 deletions src/main/java/org/opensearch/security/support/Base64Helper.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,20 @@

import java.io.Serializable;

import org.opensearch.Version;

public class Base64Helper {

public static String serializeObject(final Serializable object, final boolean useJDKSerialization) {
return useJDKSerialization ? Base64JDKHelper.serializeObject(object) : Base64CustomHelper.serializeObject(object);
}

public static String serializeObject(final Serializable object) {
return serializeObject(object, false);
return serializeObject(object, true);
}

public static Serializable deserializeObject(final String string) {
return deserializeObject(string, false);
return deserializeObject(string, true);
}

public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
Expand Down Expand Up @@ -69,4 +71,32 @@ public static String ensureJDKSerialized(final String string) {
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, true);
}

/**
* Ensures that the returned string is custom serialized.
*
* If the supplied string is a JDK serialized representation, will deserialize it and further serialize using
* custom, otherwise returns the string as is.
*
* @param string original string, can be JDK or custom serialized
* @return custom serialized string
*/
public static String ensureCustomSerialized(final String string) {
Serializable serializable;
try {
serializable = Base64Helper.deserializeObject(string, true);
} catch (Exception e) {
// We received an exception when de-serializing the given string. It is probably custom serialized.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have a check from the exception message to confirm that the cause of the exception was indeed deserialization failure using Custom serialization (this would allow to skip the check on line 91 and avoid unnecessary deserialization)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're trying to deserialize via JDK here. Depending on the exception message doesn't look like a good idea as it may vary across versions of JDK or a change in the underlying implementation of the Base64JDKHelper in future.

// Try to deserialize using custom
Base64Helper.deserializeObject(string, false);
// Since we could deserialize the object using custom, the string is already custom serialized, return as is
return string;
}
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, false);
}

public static boolean shouldUseJDKSerialization(Version remoteVersion) {
return !remoteVersion.equals(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.opensearch.transport.TransportResponseHandler;

import static org.opensearch.security.OpenSearchSecurityPlugin.isActionTraceEnabled;
import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization;

public class SecurityInterceptor {

Expand Down Expand Up @@ -148,7 +149,7 @@ public <T extends TransportResponse> void sendRequestDecorate(
final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS);

final boolean isDebugEnabled = log.isDebugEnabled();
final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
final boolean useJDKSerialization = shouldUseJDKSerialization(connection.getVersion());
final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode());

try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) {
Expand Down Expand Up @@ -226,13 +227,13 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
);
}

if (useJDKSerialization) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
if (!useJDKSerialization) {
Map<String, String> customSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
.forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k))));
headerMap.putAll(customSerializedHeaders);
}

getThreadContext().putHeader(headerMap);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ public void testSerde() {
String test = "string";
Assert.assertEquals(test, ds(test));
Assert.assertEquals(test, dsJDK(test));

// verify that default methods use JDK serialization
Assert.assertEquals(serializeObject(test), serializeObject(test, true));
String serialized = serializeObject(test);
Assert.assertEquals(deserializeObject(serialized), deserializeObject(serialized, true));
}

@Test
Expand All @@ -48,4 +53,13 @@ public void testEnsureJDKSerialized() {
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized));
Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized));
}

@Test
public void testEnsureCustomSerialized() {
String test = "string";
String jdkSerialized = Base64Helper.serializeObject(test, true);
String customSerialized = Base64Helper.serializeObject(test, false);
Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(jdkSerialized));
Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(customSerialized));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.extensions.ExtensionsManager;
Expand Down Expand Up @@ -108,8 +109,7 @@ public void setup() {
);
}

private void testSendRequestDecorate(Version remoteNodeVersion) {
boolean useJDKSerialization = remoteNodeVersion.before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
private void testSendRequestDecorate(DiscoveryNode localNode, DiscoveryNode otherNode, boolean shouldUseJDKSerialization) {
ClusterName clusterName = ClusterName.DEFAULT;
when(clusterService.getClusterName()).thenReturn(clusterName);

Expand Down Expand Up @@ -143,17 +143,7 @@ private void testSendRequestDecorate(Version remoteNodeVersion) {
@SuppressWarnings("unchecked")
TransportResponseHandler<TransportResponse> handler = mock(TransportResponseHandler.class);

InetAddress localAddress = null;
try {
localAddress = InetAddress.getByName("0.0.0.0");
} catch (final UnknownHostException uhe) {
throw new RuntimeException(uhe);
}

DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 1234), Version.CURRENT);
Connection connection1 = transportService.getConnection(localNode);

DiscoveryNode otherNode = new DiscoveryNode("remote-node", new TransportAddress(localAddress, 4321), remoteNodeVersion);
Connection connection2 = transportService.getConnection(otherNode);

// from thread context inside sendRequestDecorate
Expand Down Expand Up @@ -189,7 +179,7 @@ public <T extends TransportResponse> void sendRequest(
TransportResponseHandler<T> handler
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, useJDKSerialization));
assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, shouldUseJDKSerialization));
}
};
// isSameNodeRequest = false
Expand All @@ -201,17 +191,49 @@ public <T extends TransportResponse> void sendRequest(
assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null);
}

/**
* Tests the scenario when remote node is on same OS version
*/
@Test
public void testSendRequestDecorate() {
testSendRequestDecorate(Version.CURRENT);
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), Version.CURRENT);
testSendRequestDecorate(localNode, otherNode, true);
}

/**
* Tests the scenario when remote node does not implement custom serialization protocol and uses JDK serialization
* Tests the scenarios for mixed node versions
*/
@Test
public void testSendRequestDecorateWhenRemoteNodeUsesJDKSerde() {
testSendRequestDecorate(Version.V_2_0_0);
public void testSendRequestDecorateWithMixedNodeVersions() {

// local on latest version, remote on 2.11.0 - should use custom

try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
DiscoveryNode otherNode = new DiscoveryNode(
"other-node",
new TransportAddress(getLocalAddress(), 3456),
ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION
);
testSendRequestDecorate(localNode, otherNode, false);
}

// remote node is on a version > 2.11.1 while local node is on version 2.11.1 - should use JDK
try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) {
DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT);
DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), Version.V_2_11_1);
testSendRequestDecorate(localNode, otherNode, true);
}

}

private static InetAddress getLocalAddress() {
try {
return InetAddress.getByName("0.0.0.0");
} catch (final UnknownHostException uhe) {
throw new RuntimeException(uhe);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,15 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
when(transportChannel.getVersion()).thenReturn(Version.CURRENT);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

}

@Test
Expand All @@ -113,14 +114,14 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0);
when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));

threadPool.getThreadContext().stashContext();
when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0);
when(transportChannel.getVersion()).thenReturn(Version.CURRENT);
Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task));
Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION));
}

@Test
Expand Down