diff --git a/core/src/main/java/io/grpc/internal/CertificateUtils.java b/core/src/main/java/io/grpc/internal/CertificateUtils.java index cc26cedb274..ef66a9ed1f1 100644 --- a/core/src/main/java/io/grpc/internal/CertificateUtils.java +++ b/core/src/main/java/io/grpc/internal/CertificateUtils.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.security.GeneralSecurityException; @@ -32,12 +33,25 @@ /** * Contains certificate/key PEM file utility method(s) for internal usage. */ -public final class CertificateUtils { +public class CertificateUtils { /** * Creates X509TrustManagers using the provided CA certs. */ - public static TrustManager[] createTrustManager(InputStream rootCerts) + public static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException { + InputStream rootCertsStream = new ByteArrayInputStream(rootCerts); + try { + return CertificateUtils.createTrustManager(rootCertsStream); + } finally { + GrpcUtil.closeQuietly(rootCertsStream); + } + } + + /** + * Creates X509TrustManagers using the provided input stream of CA certs. + */ + public static TrustManager[] createTrustManager(InputStream rootCerts) + throws GeneralSecurityException { KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); try { ks.load(null, null); @@ -52,13 +66,13 @@ public static TrustManager[] createTrustManager(InputStream rootCerts) } TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(ks); return trustManagerFactory.getTrustManagers(); } private static X509Certificate[] getX509Certificates(InputStream inputStream) - throws CertificateException { + throws CertificateException { CertificateFactory factory = CertificateFactory.getInstance("X.509"); Collection certs = factory.generateCertificates(inputStream); return certs.toArray(new X509Certificate[0]); diff --git a/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java new file mode 100644 index 00000000000..6e6a6f12a39 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java @@ -0,0 +1,117 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import java.io.IOException; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; + +/** A no-op ssl socket, to facilitate overriding only the required methods in specific + * implementations. + */ +class NoopSslSocket extends SSLSocket { + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 7eaaa6fd763..98f764132fe 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.CertificateUtils.createTrustManager; import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; @@ -89,6 +90,7 @@ public final class OkHttpChannelBuilder extends ForwardingChannelBuilder2 ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap(); private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final String GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK = + "GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK"; + static boolean enablePerRpcAuthorityCheck = + GrpcUtil.getFlag(GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK, false); + private final ChannelCredentials channelCredentials; + private Socket sock; + private SSLSession sslSession; private static Map buildErrorCodeToStatusMap() { Map errorToStatus = new EnumMap<>(ErrorCode.class); @@ -144,6 +167,22 @@ private static Map buildErrorCodeToStatusMap() { return Collections.unmodifiableMap(errorToStatus); } + private static Class x509ExtendedTrustManagerClass; + private static Method checkServerTrustedMethod; + + static { + try { + x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + checkServerTrustedMethod = x509ExtendedTrustManagerClass.getMethod("checkServerTrusted", + X509Certificate[].class, String.class, Socket.class); + } catch (ClassNotFoundException e) { + // Per-rpc authority override via call options will be disallowed. + } catch (NoSuchMethodException e) { + // Should never happen since X509ExtendedTrustManager was introduced in Android API level 24 + // along with checkServerTrusted. + } + } + private final InetSocketAddress address; private final String defaultAuthority; private final String userAgent; @@ -205,6 +244,21 @@ private static Map buildErrorCodeToStatusMap() { private final boolean useGetForSafeMethods; @GuardedBy("lock") private final TransportTracer transportTracer; + private final TrustManager x509ExtendedTrustManager; + + @SuppressWarnings("serial") + private static class LruCache extends LinkedHashMap { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 100; + } + } + + private final Map hostnameVerificationResults = + Collections.synchronizedMap(new LruCache<>()); + private final Map peerVerificationResults = + Collections.synchronizedMap(new LruCache<>()); + @GuardedBy("lock") private final InUseStateAggregator inUseState = new InUseStateAggregator() { @@ -233,13 +287,14 @@ protected void handleNotInUse() { SettableFuture connectedFuture; public OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this( transportFactory, address, @@ -249,19 +304,21 @@ public OkHttpClientTransport( GrpcUtil.STOPWATCH_SUPPLIER, new Http2(), proxiedAddr, - tooManyPingsRunnable); + tooManyPingsRunnable, + channelCredentials); } private OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - Supplier stopwatchFactory, - Variant variant, - @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + Supplier stopwatchFactory, + Variant variant, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = transportFactory.maxMessageSize; @@ -291,7 +348,23 @@ private OkHttpClientTransport( this.attributes = Attributes.newBuilder() .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); this.useGetForSafeMethods = transportFactory.useGetForSafeMethods; + this.channelCredentials = channelCredentials; initTransportTracer(); + TrustManager tempX509ExtendedTrustManager; + if (channelCredentials instanceof TlsChannelCredentials + && x509ExtendedTrustManagerClass != null) { + try { + tempX509ExtendedTrustManager = getX509ExtendedTrustManager( + (TlsChannelCredentials) channelCredentials); + } catch (GeneralSecurityException e) { + tempX509ExtendedTrustManager = null; + log.log(Level.WARNING, "Obtaining X509ExtendedTrustManager for the transport failed." + + "Per-rpc authority overrides will be disallowed.", e); + } + } else { + tempX509ExtendedTrustManager = null; + } + x509ExtendedTrustManager = tempX509ExtendedTrustManager; } /** @@ -316,7 +389,8 @@ private OkHttpClientTransport( stopwatchFactory, variant, null, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); } @@ -396,6 +470,7 @@ public OkHttpClientStream newStream( Preconditions.checkNotNull(headers, "headers"); StatsTraceContext statsTraceContext = StatsTraceContext.newClientContext(tracers, getAttributes(), headers); + // FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope synchronized (lock) { // to make @GuardedBy linter happy return new OkHttpClientStream( @@ -416,8 +491,30 @@ public OkHttpClientStream newStream( } } + private TrustManager getX509ExtendedTrustManager(TlsChannelCredentials tlsCreds) + throws GeneralSecurityException { + TrustManager[] tm = null; + // Using the same way of creating TrustManager from OkHttpChannelBuilder.sslSocketFactoryFrom() + if (tlsCreds.getTrustManagers() != null) { + tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); + } else if (tlsCreds.getRootCertificates() != null) { + tm = CertificateUtils.createTrustManager(tlsCreds.getRootCertificates()); + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore) null); + tm = tmf.getTrustManagers(); + } + for (TrustManager trustManager: tm) { + if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { + return trustManager; + } + } + return null; + } + @GuardedBy("lock") - void streamReadyToStart(OkHttpClientStream clientStream) { + void streamReadyToStart(OkHttpClientStream clientStream, String authority) { if (goAwayStatus != null) { clientStream.transportState().transportReportStatus( goAwayStatus, RpcProgress.MISCARRIED, true, new Metadata()); @@ -425,6 +522,80 @@ void streamReadyToStart(OkHttpClientStream clientStream) { pendingStreams.add(clientStream); setInUse(clientStream); } else { + if (hostnameVerifier != null && socket instanceof SSLSocket) { + Boolean hostnameVerificationResult; + if (hostnameVerificationResults.containsKey(authority)) { + hostnameVerificationResult = hostnameVerificationResults.get(authority); + } else { + hostnameVerificationResult = hostnameVerifier.verify( + authority, ((SSLSocket) socket).getSession()); + hostnameVerificationResults.put(authority, hostnameVerificationResult); + } + if (!hostnameVerificationResult) { + if (enablePerRpcAuthorityCheck) { + clientStream.transportState().transportReportStatus( + Status.UNAVAILABLE.withDescription( + String.format("HostNameVerifier verification failed for authority '%s'", + authority)), + RpcProgress.DROPPED, true, new Metadata()); + return; + } else { + log.log(Level.WARNING, String.format("HostNameVerifier verification failed for " + + "authority '%s'. This will be an error in the future.", + authority)); + } + } + } + if (socket instanceof SSLSocket && authority != null + && channelCredentials != null + && channelCredentials instanceof TlsChannelCredentials) { + Status peerVerificationStatus = null; + if (peerVerificationResults.containsKey(authority)) { + peerVerificationStatus = peerVerificationResults.get(authority); + } else { + if (x509ExtendedTrustManager == null) { + peerVerificationStatus = Status.UNAVAILABLE.withDescription( + String.format("Could not verify authority '%s' for the rpc with no " + + "X509ExtendedTrustManager available", + authority)); + } + if (x509ExtendedTrustManager != null) { + try { + Certificate[] peerCertificates = sslSession.getPeerCertificates(); + X509Certificate[] x509PeerCertificates = new X509Certificate[peerCertificates.length]; + for (int i = 0; i < peerCertificates.length; i++) { + x509PeerCertificates[i] = (X509Certificate) peerCertificates[i]; + } + checkServerTrustedMethod.invoke(x509ExtendedTrustManager, x509PeerCertificates, + "RSA", new SslSocketWrapper((SSLSocket) socket, authority)); + peerVerificationStatus = Status.OK; + } catch (SSLPeerUnverifiedException | InvocationTargetException + | IllegalAccessException e) { + peerVerificationStatus = Status.UNAVAILABLE.withCause(e).withDescription( + "Peer verification failed"); + } + if (peerVerificationStatus != null) { + peerVerificationResults.put(authority, peerVerificationStatus); + } + } + } + if (peerVerificationStatus != null && !peerVerificationStatus.isOk()) { + if (enablePerRpcAuthorityCheck) { + clientStream.transportState().transportReportStatus( + peerVerificationStatus, RpcProgress.DROPPED, true, new Metadata()); + return; + } else { + if (peerVerificationStatus.getCause() != null) { + log.log(Level.WARNING, peerVerificationStatus.getDescription() + + ". This will be an error in the future.", + peerVerificationStatus.getCause()); + } else { + log.log(Level.WARNING, peerVerificationStatus.getDescription() + + ". This will be an error in the future."); + } + } + } + } startStream(clientStream); } } @@ -531,8 +702,6 @@ public Timeout timeout() { public void close() { } }); - Socket sock; - SSLSession sslSession = null; try { // This is a hack to make sure the connection preface and initial settings to be sent out // without blocking the start. By doing this essentially prevents potential deadlock when @@ -1460,4 +1629,50 @@ public void alternateService(int streamId, String origin, ByteString protocol, S // TODO(madongfly): Deal with alternateService propagation } } + + /** + * SSLSocket wrapper that provides a fake SSLSession for handshake session. + */ + static final class SslSocketWrapper extends NoopSslSocket { + + private final SSLSession sslSession; + private final SSLSocket sslSocket; + + SslSocketWrapper(SSLSocket sslSocket, String peerHost) { + this.sslSocket = sslSocket; + this.sslSession = new FakeSslSession(peerHost); + } + + @Override + public SSLSession getHandshakeSession() { + return this.sslSession; + } + + @Override + public boolean isConnected() { + return sslSocket.isConnected(); + } + + @Override + public SSLParameters getSSLParameters() { + return sslSocket.getSSLParameters(); + } + } + + /** + * Fake SSLSession instance that provides the peer host name to verify for per-rpc check. + */ + static class FakeSslSession extends NoopSslSession { + + private final String peerHost; + + FakeSslSession(String peerHost) { + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index 068474d70bc..8daeed42a8c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkArgument; +import static io.grpc.internal.CertificateUtils.createTrustManager; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -425,7 +426,7 @@ static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentia tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); } else if (tlsCreds.getRootCertificates() != null) { try { - tm = OkHttpChannelBuilder.createTrustManager(tlsCreds.getRootCertificates()); + tm = createTrustManager(tlsCreds.getRootCertificates()); } catch (GeneralSecurityException gse) { log.log(Level.FINE, "Exception loading root certificates from credential", gse); return HandshakerSocketFactoryResult.error( diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 1f716705968..1c98d6ee30d 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.times; @@ -244,12 +245,13 @@ public void getUnaryRequest() throws IOException { // GET streams send headers after halfClose is called. verify(mockedFrameWriter, times(0)).synStream( eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); - verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); + verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class), + isA(String.class)); byte[] msg = "request".getBytes(Charset.forName("UTF-8")); stream.writeMessage(new ByteArrayInputStream(msg)); stream.halfClose(); - verify(transport).streamReadyToStart(eq(stream)); + verify(transport).streamReadyToStart(eq(stream), any(String.class)); stream.transportState().start(3); verify(mockedFrameWriter) diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index daf5073992e..4ad1266f3b0 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -69,6 +69,7 @@ import io.grpc.Status.Code; import io.grpc.StatusException; import io.grpc.internal.AbstractStream; +import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; import io.grpc.internal.FakeClock; @@ -116,6 +117,10 @@ import java.util.logging.Logger; import javax.annotation.Nullable; import javax.net.SocketFactory; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; @@ -190,16 +195,24 @@ public void tearDown() { private void initTransport() throws Exception { startTransport( - DEFAULT_START_STREAM_ID, null, true, null); + DEFAULT_START_STREAM_ID, null, true, null, null); } private void initTransport(int startId) throws Exception { - startTransport(startId, null, true, null); + startTransport(startId, null, true, null, null); } private void startTransport(int startId, @Nullable Runnable connectingCallback, - boolean waitingForConnected, String userAgent) - throws Exception { + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier) throws Exception { + startTransport(startId, connectingCallback, waitingForConnected, userAgent, hostnameVerifier, + false); + } + + private void startTransport(int startId, @Nullable Runnable connectingCallback, + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier, boolean useSslSocket) + throws Exception { connectedFuture = SettableFuture.create(); final Ticker ticker = new Ticker() { @Override @@ -213,7 +226,11 @@ public Stopwatch get() { return Stopwatch.createUnstarted(ticker); } }; - channelBuilder.socketFactory(new FakeSocketFactory(socket)); + channelBuilder.socketFactory( + new FakeSocketFactory(useSslSocket ? new MockSslSocket(socket) : socket)); + if (hostnameVerifier != null) { + channelBuilder = channelBuilder.hostnameVerifier(hostnameVerifier); + } clientTransport = new OkHttpClientTransport( channelBuilder.buildTransportFactory(), userAgent, @@ -241,7 +258,8 @@ public void testToString() throws Exception { /*userAgent=*/ null, EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String s = clientTransport.toString(); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains(address.toString())); @@ -259,7 +277,8 @@ public void testTransportExecutorWithTooFewThreads() throws Exception { null, EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture()); @@ -300,7 +319,7 @@ public void close() throws SecurityException { assertThat(log.getLevel()).isEqualTo(Level.FINE); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -390,7 +409,7 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -443,11 +462,11 @@ public void nextFrameThrowIoException() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -477,7 +496,7 @@ public void nextFrameThrowIoException() throws Exception { public void nextFrameThrowsError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -498,7 +517,7 @@ public void nextFrameThrowsError() throws Exception { public void nextFrameReturnFalse() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -516,7 +535,7 @@ public void readMessages() throws Exception { final int numMessages = 10; final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(numMessages); @@ -568,7 +587,7 @@ public void receivedDataForInvalidStreamShouldKillConnection() throws Exception public void invalidInboundHeadersCancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -593,7 +612,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { public void invalidInboundTrailersPropagateToMetadata() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -613,7 +632,7 @@ public void invalidInboundTrailersPropagateToMetadata() throws Exception { public void readStatus() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -627,7 +646,7 @@ public void readStatus() throws Exception { public void receiveReset() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -644,7 +663,7 @@ public void receiveReset() throws Exception { public void receiveResetNoError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -665,7 +684,7 @@ public void receiveResetNoError() throws Exception { public void cancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.CANCELLED); @@ -680,7 +699,7 @@ public void cancelStream() throws Exception { public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); Header userAgentHeader = new Header(GrpcUtil.USER_AGENT_KEY.name(), @@ -697,9 +716,9 @@ public void addDefaultUserAgent() throws Exception { @Test public void overrideDefaultUserAgent() throws Exception { - startTransport(3, null, true, "fakeUserAgent"); + startTransport(3, null, true, "fakeUserAgent", null); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); List
expectedHeaders = Arrays.asList(HTTP_SCHEME_HEADER, METHOD_HEADER, @@ -718,7 +737,7 @@ public void overrideDefaultUserAgent() throws Exception { public void cancelStreamForDeadlineExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.DEADLINE_EXCEEDED); @@ -732,7 +751,7 @@ public void writeMessage() throws Exception { initTransport(); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); @@ -747,6 +766,65 @@ public void writeMessage() throws Exception { shutdownAndVerify(); } + @Test + public void perRpcAuthoritySpecified_verificationSkippedInPlainTextConnection() + throws Exception { + initTransport(); + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + ClientStream stream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + stream.start(listener); + InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); + assertEquals(12, input.available()); + stream.writeMessage(input); + stream.flush(); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); + assertEquals(createMessageFrame(message), sentFrame); + stream.cancel(Status.CANCELLED); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_ignoredForNonSslSocket() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> false, false); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_successCase() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> true, true); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_flagDisabled() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> false, true); + ClientStream clientStream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + assertThat(clientStream).isInstanceOf(OkHttpClientStream.class); + shutdownAndVerify(); + } + @Test public void transportTracer_windowSizeDefault() throws Exception { initTransport(); @@ -773,12 +851,12 @@ public void windowUpdate() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(2); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(2); @@ -843,7 +921,7 @@ public void windowUpdate() throws Exception { public void windowUpdateWithInboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = INITIAL_WINDOW_SIZE / 2 + 1; @@ -880,7 +958,7 @@ public void windowUpdateWithInboundFlowControl() throws Exception { public void outboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -926,7 +1004,7 @@ public void outboundFlowControl_smallWindowSize() throws Exception { setInitialWindowSize(initialOutboundWindowSize); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -969,7 +1047,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { frameHandler().windowUpdate(0, 65535); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -1005,7 +1083,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1051,7 +1129,7 @@ public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { public void outboundFlowControlWithInitialWindowSizeChangeInMiddleOfStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1086,10 +1164,10 @@ public void stopNormally() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); assertEquals(2, activeStreamCount()); @@ -1116,11 +1194,11 @@ public void receiveGoAway() throws Exception { // start 2 streams. MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -1143,7 +1221,7 @@ public void receiveGoAway() throws Exception { // But stream 1 should be able to send. final String sentMessage = "Should I also go away?"; - OkHttpClientStream stream = getStream(3); + ClientStream stream = getStream(3); InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); assertEquals(22, input.available()); stream.writeMessage(input); @@ -1175,7 +1253,7 @@ public void streamIdExhausted() throws Exception { initTransport(startId); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1211,11 +1289,11 @@ public void pendingStreamSucceed() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); String sentMessage = "hello"; @@ -1248,7 +1326,7 @@ public void pendingStreamCancelled() throws Exception { initTransport(); setMaxConcurrentStreams(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1267,11 +1345,11 @@ public void pendingStreamFailedByGoAway() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); @@ -1297,7 +1375,7 @@ public void pendingStreamSucceedAfterShutdown() throws Exception { setMaxConcurrentStreams(0); final MockStreamListener listener = new MockStreamListener(); // The second stream should be pending. - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1321,15 +1399,15 @@ public void pendingStreamFailedByIdExhausted() throws Exception { final MockStreamListener listener2 = new MockStreamListener(); final MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second and third stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -1353,7 +1431,7 @@ public void pendingStreamFailedByIdExhausted() throws Exception { public void receivingWindowExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1405,7 +1483,7 @@ public void duplexStreamingHeadersShouldNotBeFlushed() throws Exception { private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); verify(frameWriter, timeout(TIME_OUT_MS)).synStream( @@ -1422,7 +1500,7 @@ private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { public void receiveDataWithoutHeader() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1445,7 +1523,7 @@ public void receiveDataWithoutHeader() throws Exception { public void receiveDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1469,7 +1547,7 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1491,7 +1569,7 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1520,7 +1598,7 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception public void receiveWindowUpdateForUnknownStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1540,7 +1618,7 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { public void shouldBeInitiallyReady() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1558,7 +1636,7 @@ public void notifyOnReady() throws Exception { AbstractStream.TransportState.DEFAULT_ONREADY_THRESHOLD - HEADER_LENGTH - 1; setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1711,7 +1789,7 @@ public void shutdownDuringConnecting() throws Exception { DEFAULT_START_STREAM_ID, connectingCallback, false, - null); + null, null); clientTransport.shutdown(SHUTDOWN_REASON); delayed.set(null); shutdownAndVerify(); @@ -1726,7 +1804,8 @@ public void invalidAuthorityPropagates() { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String host = clientTransport.getOverridenHost(); int port = clientTransport.getOverridenPort(); @@ -1744,7 +1823,8 @@ public void unreachableServer() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); @@ -1774,7 +1854,8 @@ public void customSocketFactory() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); @@ -1799,7 +1880,8 @@ public void proxy_200() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1848,7 +1930,8 @@ public void proxy_500() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1896,7 +1979,8 @@ public void proxy_immediateServerClose() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1927,7 +2011,8 @@ public void proxy_serverHangs() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.proxySocketTimeout = 10; clientTransport.start(transportListener); @@ -1994,13 +2079,13 @@ public void goAway_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2034,13 +2119,13 @@ public void reset_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -2076,13 +2161,13 @@ public void shutdownNow_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2107,11 +2192,11 @@ public void finishedStreamRemovedFromInUseState() throws Exception { initTransport(); setMaxConcurrentStreams(1); final MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); + OkHttpClientStream stream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); - OkHttpClientStream pendingStream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); + OkHttpClientStream pendingStream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); pendingStream.start(listener); waitForStreamPending(1); clientTransport.finishStream(stream.transportState().id(), Status.OK, PROCESSED, @@ -2151,7 +2236,7 @@ private void waitForStreamPending(int expected) throws Exception { private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); listener.waitUntilStreamClosed(); @@ -2382,6 +2467,124 @@ public InputStream getInputStream() { } } + private static class MockSslSocket extends SSLSocket { + private Socket delegate; + + MockSslSocket(Socket socket) { + delegate = socket; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } + + @Override + public synchronized void close() throws IOException { + delegate.close(); + } + + @Override + public SocketAddress getLocalSocketAddress() { + return delegate.getLocalSocketAddress(); + } + + @Override + public OutputStream getOutputStream() throws IOException { + return delegate.getOutputStream(); + } + + @Override + public InputStream getInputStream() throws IOException { + return delegate.getInputStream(); + } + } + static class PingCallbackImpl implements ClientTransport.PingCallback { int invocationCount; long roundTripTime; diff --git a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java index a21360a89ba..68d1a55ce56 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java @@ -18,8 +18,10 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.fail; import com.google.common.base.Throwables; +import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; @@ -32,18 +34,35 @@ import io.grpc.TlsServerCredentials; import io.grpc.internal.testing.TestUtils; import io.grpc.okhttp.internal.Platform; +import io.grpc.stub.ClientCalls; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.TlsTesting; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.util.CertificateUtils; import java.io.IOException; import java.io.InputStream; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Optional; +import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -53,6 +72,7 @@ /** Verify OkHttp's TLS integration. */ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class TlsTest { @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @@ -92,6 +112,305 @@ public void basicTls_succeeds() throws Exception { SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()); } + @Test + public void perRpcAuthorityOverride_hostnameVerification_success() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa( + Platform.get().getProvider(), TestUtils.loadCert("ca.pem")); + ManagedChannel channel = grpcCleanupRule.register(grpcCleanupRule.register( + OkHttpChannelBuilder.forAddress("localhost", server.getPort()) + .directExecutor() + .sslSocketFactory(sslSocketFactory) + .hostnameVerifier(new HostnameVerifier() { + private int callCount; + @Override + public boolean verify(String hostname, SSLSession session) { + if (++callCount == 1) { + return true; + } + return hostname.equals("foo.test.google.fr"); + } + }) + .build())); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerification_failure_rpcFails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa( + Platform.get().getProvider(), TestUtils.loadCert("ca.pem")); + ManagedChannel channel = grpcCleanupRule.register(grpcCleanupRule.register( + OkHttpChannelBuilder.forAddress("localhost", server.getPort()) + .directExecutor() + .sslSocketFactory(sslSocketFactory) + .hostnameVerifier(new HostnameVerifier() { + private int callCount; + @Override + public boolean verify(String hostname, SSLSession session) { + if (++callCount == 1) { + return true; + } + return hostname.equals("foo.test.google.fr"); + } + }) + .build())); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.in"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for hostname verifier failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "HostNameVerifier verification failed for authority 'foo.test.google.in'"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerification_failure_flagDisabled_rpcSucceeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa( + Platform.get().getProvider(), TestUtils.loadCert("ca.pem")); + ManagedChannel channel = grpcCleanupRule.register(grpcCleanupRule.register( + OkHttpChannelBuilder.forAddress("localhost", server.getPort()) + .directExecutor() + .sslSocketFactory(sslSocketFactory) + .hostnameVerifier(new HostnameVerifier() { + private int callCount; + @Override + public boolean verify(String hostname, SSLSession session) { + if (++callCount == 1) { + return true; + } + return hostname.equals("foo.test.google.fr"); + } + }) + .build())); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.in"), + SimpleRequest.getDefaultInstance()); + } + + /** + * Test to verify that checkServerTrusted is been called for the per-rpc authority + * verification. Could not verify this indirectly using an invalid authority in the test because + * the SSLParameters.setEndpointIdentificationAlgorithm (an Android 24+ API) is not set to HTTPS + * during the handshake and doesn't verify hostname during rpc call either. + */ + @Test + public void perRpcAuthorityOverride_checkServerTrustedIsCalled() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + FakeX509ExtendedTrustManager fakeTrustManager; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager x509ExtendedTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + fakeTrustManager = new FakeX509ExtendedTrustManager(x509ExtendedTrustManager); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(fakeTrustManager) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + assertThat(fakeTrustManager.checkServerTrustedCalled).isTrue(); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + /** + * Uses a fake Trust Manager to fail authority verification for rpc identified by the call count. + */ + @Test + public void perRpcAuthorityOverride_peerVerificationFails_rpcFails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + FakeX509ExtendedTrustManager fakeTrustManager; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager x509ExtendedTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + fakeTrustManager = new FakeX509ExtendedTrustManager(x509ExtendedTrustManager); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(fakeTrustManager) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + fakeTrustManager.setFailCheckServerTrusted(); + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getCause().getCause()).isInstanceOf(CertificateException.class); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_peerVerificationFails_featureDisabled_rpcSucceeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + FakeX509ExtendedTrustManager fakeTrustManager; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager x509ExtendedTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + fakeTrustManager = new FakeX509ExtendedTrustManager(x509ExtendedTrustManager); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(fakeTrustManager) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + fakeTrustManager.setFailCheckServerTrusted(); + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } + + /** + * This negative test simulates the absence of X509ExtendedTrustManager while still using the + * real trust manager for the connection handshake to happen. + */ + @Test + public void perRpcAuthorityOverride_tlsCreds_noX509ExtendedTrustManager_fails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "Could not verify authority 'foo.test.google.fr' for the rpc with no " + + "X509ExtendedTrustManager available"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_tlsCreds_noX509ExtendedTrustManager_flagDisabled() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } + @Test public void mtls_succeeds() throws Exception { ServerCredentials serverCreds; @@ -282,6 +601,110 @@ public void hostnameVerifierFails_fails() assertThat(status.getCause()).isInstanceOf(SSLPeerUnverifiedException.class); } + /** Used to simulate the case of X509ExtendedTrustManager not present. */ + private static class FakeTrustManager implements X509TrustManager { + + private final X509TrustManager delegate; + + public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) { + this.delegate = x509ExtendedTrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } + + @IgnoreJRERequirement + private static class FakeX509ExtendedTrustManager extends X509ExtendedTrustManager { + private final X509ExtendedTrustManager delegate; + private boolean checkServerTrustedCalled; + private boolean shouldFailCheckServerTrustedForRpc; + private int numCalls; + + private FakeX509ExtendedTrustManager(X509ExtendedTrustManager delegate) { + this.delegate = delegate; + } + + private void setFailCheckServerTrusted() { + shouldFailCheckServerTrustedForRpc = true; + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + this.checkServerTrustedCalled = true; + if (shouldFailCheckServerTrustedForRpc && ++numCalls > 1) { + throw new CertificateException("Peer verification failed."); + } + delegate.checkServerTrusted(chain, authType, socket); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + delegate.checkServerTrusted(chain, authType, engine); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + delegate.checkServerTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) { + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) { + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + } + + private static Optional getX509ExtendedTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return Arrays.stream(trustManagerFactory.getTrustManagers()) + .filter(trustManager -> trustManager instanceof X509ExtendedTrustManager).findFirst(); + } + private static Server server(ServerCredentials creds) throws IOException { return OkHttpServerBuilder.forPort(0, creds) .directExecutor()