Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with WebSocket UpgradeRequest after upgrade is complete #12877

Open
wants to merge 6 commits into
base: jetty-12.1.x
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Issue #11294 - fix access to websocket UpgradeRequest after upgrade
Signed-off-by: Lachlan Roberts <lachlan.p.roberts@gmail.com>
lachlan-roberts committed Mar 5, 2025
commit 9818975ece68fc5981eb44f169fabc0daf932fce
Original file line number Diff line number Diff line change
@@ -36,19 +36,26 @@ class UpgradeRequestDelegate implements UpgradeRequest
{
private final ServerUpgradeRequest request;
private final Map<String, List<String>> headers;
private final List<HttpCookie> cookies;
private final Principal userPrincipal;

UpgradeRequestDelegate(ServerUpgradeRequest request)
{
this.request = request;
this.headers = HttpFields.asMap(request.getHeaders());

Request.AuthenticationState authenticationState = Request.getAuthenticationState(request);
userPrincipal = (authenticationState == null) ? null : authenticationState.getUserPrincipal();

this.cookies = Request.getCookies(request).stream()
.map(org.eclipse.jetty.http.HttpCookie::asJavaNetHttpCookie)
.toList();
}

@Override
public List<HttpCookie> getCookies()
{
return Request.getCookies(request).stream()
.map(org.eclipse.jetty.http.HttpCookie::asJavaNetHttpCookie)
.toList();
return cookies;
}

@Override
@@ -149,8 +156,7 @@ public List<String> getSubProtocols()
@Override
public Principal getUserPrincipal()
{
// TODO: no Principal concept in Jetty core.
return null;
return userPrincipal;
}

@Override
Original file line number Diff line number Diff line change
@@ -19,6 +19,8 @@
import java.util.stream.Collectors;

import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.websocket.api.ExtensionConfig;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.common.JettyExtensionConfig;
@@ -28,11 +30,16 @@ class UpgradeResponseDelegate implements UpgradeResponse
{
private final ServerUpgradeResponse response;
private final Map<String, List<String>> headers;
private final int status;

UpgradeResponseDelegate(ServerUpgradeResponse response)
{
this.response = response;
this.headers = HttpFields.asMap(response.getHeaders());
this.headers = HttpFields.asMap(response.getHeaders().asImmutable());

// Fake status code as it not set at the time this is created.
HttpVersion httpVersion = response.getRequest().getConnectionMetaData().getHttpVersion();
this.status = (httpVersion == HttpVersion.HTTP_1_1) ? HttpStatus.SWITCHING_PROTOCOLS_101 : HttpStatus.OK_200;
}

@Override
@@ -76,6 +83,6 @@ public List<String> getHeaders(String name)
@Override
public int getStatusCode()
{
return response.getStatus();
return status;
}
}
Original file line number Diff line number Diff line change
@@ -72,6 +72,11 @@
<artifactId>jetty-websocket-jetty-server</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-security</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.websocket.tests;

import java.net.URI;
import java.util.List;
import java.util.concurrent.TimeUnit;

import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.security.AbstractLoginService;
import org.eclipse.jetty.security.AuthenticationState;
import org.eclipse.jetty.security.Constraint;
import org.eclipse.jetty.security.DefaultIdentityService;
import org.eclipse.jetty.security.IdentityService;
import org.eclipse.jetty.security.RolePrincipal;
import org.eclipse.jetty.security.SecurityHandler;
import org.eclipse.jetty.security.UserIdentity;
import org.eclipse.jetty.security.UserPrincipal;
import org.eclipse.jetty.security.authentication.LoginAuthenticator;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.security.Credential;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.eclipse.jetty.websocket.server.WebSocketUpgradeHandler;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class ServerUpgradeRequestTest
{
private Server _server;
private WebSocketClient _client;
private ServerConnector _connector;

private static class TestLoginService extends AbstractLoginService
{
public TestLoginService(IdentityService identityService)
{
setIdentityService(identityService);
}

@Override
protected List<RolePrincipal> loadRoleInfo(UserPrincipal user)
{
return List.of();
}

@Override
protected UserPrincipal loadUserInfo(String username)
{
return new UserPrincipal(username, null)
{
@Override
public boolean authenticate(Object credentials)
{
return true;
}

@Override
public boolean authenticate(Credential c)
{
return true;
}

@Override
public boolean authenticate(UserPrincipal u)
{
return true;
}
};
}
}

private static class TestAuthenticator extends LoginAuthenticator
{
@Override
public String getAuthenticationType()
{
return "TEST";
}

@Override
public AuthenticationState validateRequest(Request request, Response response, org.eclipse.jetty.util.Callback callback)
{
UserIdentity user = login("user123", null, request, response);
if (user != null)
return new UserAuthenticationSucceeded(getAuthenticationType(), user);

Response.writeError(request, response, callback, HttpStatus.FORBIDDEN_403);
return AuthenticationState.SEND_FAILURE;
}
}

@BeforeEach
public void start() throws Exception
{
_server = new Server();
_connector = new ServerConnector(_server);
_server.addConnector(_connector);

WebSocketUpgradeHandler upgradeHandler = WebSocketUpgradeHandler.from(_server, container ->
{
container.addMapping("/", (req, resp, cb) ->
{
resp.getHeaders().put("customHeader", "customHeaderValue");
resp.setAcceptedSubProtocol(req.getSubProtocols().get(0));

return new ServerSocket();
});
});

SecurityHandler.PathMapped securityHandler = new SecurityHandler.PathMapped();
securityHandler.put("/*", Constraint.ANY_USER);
DefaultIdentityService identityService = new DefaultIdentityService();
securityHandler.setLoginService(new TestLoginService(identityService));
securityHandler.setIdentityService(identityService);
securityHandler.setAuthenticator(new TestAuthenticator());
securityHandler.setHandler(upgradeHandler);

_server.setHandler(securityHandler);
_server.start();

_client = new WebSocketClient();
_client.start();
}

@AfterEach
public void stop() throws Exception
{
_client.stop();
_server.stop();
}

@WebSocket
public static class ServerSocket extends EventSocket
{
@Override
public void onMessage(String message)
{
StringBuilder builder = new StringBuilder();

try
{
switch (message)
{
case "getUpgradeRequest" ->
{
UpgradeRequest upgradeRequest = session.getUpgradeRequest();
builder.append("getRequestURI: ").append(upgradeRequest.getRequestURI()).append("\n");
builder.append("getHeaders: ").append(upgradeRequest.getHeaders()).append("\n");
builder.append("getExtensions: ").append(upgradeRequest.getExtensions()).append("\n");
builder.append("getHost: ").append(upgradeRequest.getHost()).append("\n");
builder.append("getHttpVersion: ").append(upgradeRequest.getHttpVersion()).append("\n");
builder.append("getQueryString: ").append(upgradeRequest.getQueryString()).append("\n");
builder.append("getSubProtocols: ").append(upgradeRequest.getSubProtocols()).append("\n");
builder.append("getProtocolVersion: ").append(upgradeRequest.getProtocolVersion()).append("\n");
builder.append("getCookies: ").append(upgradeRequest.getCookies()).append("\n");
builder.append("getUserPrincipal: ").append(upgradeRequest.getUserPrincipal()).append("\n");
builder.append("getOrigin: ").append(upgradeRequest.getOrigin()).append("\n");
builder.append("isSecure: ").append(upgradeRequest.isSecure()).append("\n");
builder.append("getParameterMap: ").append(upgradeRequest.getParameterMap()).append("\n");
}
case "getUpgradeResponse" ->
{
UpgradeResponse upgradeResponse = session.getUpgradeResponse();
builder.append("getHeaders: ").append(upgradeResponse.getHeaders()).append("\n");
builder.append("getExtensions: ").append(upgradeResponse.getExtensions()).append("\n");
builder.append("getStatusCode: ").append(upgradeResponse.getStatusCode()).append("\n");
builder.append("getAcceptedSubProtocol: ").append(upgradeResponse.getAcceptedSubProtocol()).append("\n");
}
default -> throw new IllegalStateException("Unknown message: " + message);
}
}
catch (Exception e)
{
e.printStackTrace(System.err);
throw e;
}

session.sendText(builder.toString(), Callback.NOOP);
}
}

@Test
public void testUpgradeRequest() throws Exception
{
URI uri = new URI("ws://localhost:" + _connector.getLocalPort() + "/?queryParam1=queryParamValue1");
ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setSubProtocols("subProtocol1", "subProtocol2");
upgradeRequest.addExtensions("permessage-deflate");
upgradeRequest.setHeader("Cookie", "cookieHeader1=cookieValue1");
upgradeRequest.setHeader("Origin", "jetty-test");
upgradeRequest.setHeader("CustomRequestHeader", "request-header-value");

EventSocket clientEndpoint = new EventSocket();
Session session = _client.connect(clientEndpoint, uri, upgradeRequest).get(5, TimeUnit.SECONDS);

session.sendText("getUpgradeRequest", Callback.NOOP);
String received = clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(received, containsString("getRequestURI: " + uri));
assertThat(received, containsString("CustomRequestHeader=[request-header-value]"));
assertThat(received, containsString("getExtensions: [permessage-deflate]"));
assertThat(received, containsString("getHost: localhost"));
assertThat(received, containsString("getHttpVersion: HTTP/1.1"));
assertThat(received, containsString("getQueryString: queryParam1=queryParamValue1"));
assertThat(received, containsString("getSubProtocols: [subProtocol1, subProtocol2]"));
assertThat(received, containsString("getProtocolVersion: 13"));
assertThat(received, containsString("getCookies: [cookieHeader1=cookieValue1]"));
assertThat(received, containsString("getUserPrincipal: user123"));
assertThat(received, containsString("getOrigin: jetty-test"));
assertThat(received, containsString("isSecure: false"));
assertThat(received, containsString("getParameterMap: {queryParam1=[queryParamValue1]}"));

session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeCode, equalTo(StatusCode.NORMAL));
}

@Test
public void testUpgradeResponse() throws Exception
{
URI uri = new URI("ws://localhost:" + _connector.getLocalPort());

ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setSubProtocols("subProtocol1", "subProtocol2");
upgradeRequest.addExtensions("permessage-deflate");

EventSocket clientEndpoint = new EventSocket();
Session session = _client.connect(clientEndpoint, uri, upgradeRequest).get(5, TimeUnit.SECONDS);

session.sendText("getUpgradeResponse", Callback.NOOP);
String received = clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(received, containsString("customHeader=[customHeaderValue]"));
assertThat(received, containsString("getExtensions: [permessage-deflate]"));
assertThat(received, containsString("getStatusCode: 101"));
assertThat(received, containsString("getAcceptedSubProtocol: subProtocol1"));

session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeCode, equalTo(StatusCode.NORMAL));
}
}
Original file line number Diff line number Diff line change
@@ -42,23 +42,56 @@

public class DelegatedServerUpgradeRequest implements JettyServerUpgradeRequest
{
private final boolean upgraded;
private final URI requestURI;
private final String queryString;
private final ServerUpgradeRequest upgradeRequest;
private final HttpServletRequest httpServletRequest;
private final Principal userPrincipal;
private final String origin;
private final boolean isSecure;
private final Map<String, List<String>> headers;
private List<HttpCookie> cookies;
private Map<String, List<String>> parameterMap;

public DelegatedServerUpgradeRequest(ServerUpgradeRequest request)
{
this(request, false);
}

public DelegatedServerUpgradeRequest(ServerUpgradeRequest request, boolean upgraded)
{
this.upgraded = upgraded;
this.httpServletRequest = (HttpServletRequest)request
.getAttribute(WebSocketConstants.WEBSOCKET_WRAPPED_REQUEST_ATTRIBUTE);
this.upgradeRequest = request;
this.headers = HttpFields.asMap(upgradeRequest.getHeaders());
this.queryString = httpServletRequest.getQueryString();
this.userPrincipal = httpServletRequest.getUserPrincipal();
this.headers = HttpFields.asMap(upgradeRequest.getHeaders());
this.origin = httpServletRequest.getHeader(HttpHeader.ORIGIN.asString());
this.isSecure = httpServletRequest.isSecure();

Map<String, String[]> requestParams = httpServletRequest.getParameterMap();
if (requestParams != null)
{
parameterMap = new HashMap<>(requestParams.size());
for (Map.Entry<String, String[]> entry : requestParams.entrySet())
{
parameterMap.put(entry.getKey(), Arrays.asList(entry.getValue()));
}
}

Cookie[] reqCookies = httpServletRequest.getCookies();
if (reqCookies != null)
{
cookies = Arrays.stream(reqCookies)
.map(c -> new HttpCookie(c.getName(), c.getValue()))
.collect(Collectors.toList());
}
else
{
cookies = Collections.emptyList();
}

try
{
@@ -82,21 +115,6 @@ public ServerUpgradeRequest getServerUpgradeRequest()
@Override
public List<HttpCookie> getCookies()
{
if (cookies == null)
{
Cookie[] reqCookies = httpServletRequest.getCookies();
if (reqCookies != null)
{
cookies = Arrays.stream(reqCookies)
.map(c -> new HttpCookie(c.getName(), c.getValue()))
.collect(Collectors.toList());
}
else
{
cookies = Collections.emptyList();
}
}

return cookies;
}

@@ -153,24 +171,12 @@ public String getMethod()
@Override
public String getOrigin()
{
return httpServletRequest.getHeader(HttpHeader.ORIGIN.asString());
return origin;
}

@Override
public Map<String, List<String>> getParameterMap()
{
if (parameterMap == null)
{
Map<String, String[]> requestParams = httpServletRequest.getParameterMap();
if (requestParams != null)
{
parameterMap = new HashMap<>(requestParams.size());
for (Map.Entry<String, String[]> entry : requestParams.entrySet())
{
parameterMap.put(entry.getKey(), Arrays.asList(entry.getValue()));
}
}
}
return parameterMap;
}

@@ -195,6 +201,9 @@ public URI getRequestURI()
@Override
public HttpSession getSession()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return httpServletRequest.getSession();
}

@@ -219,42 +228,60 @@ public boolean hasSubProtocol(String subprotocol)
@Override
public boolean isSecure()
{
return httpServletRequest.isSecure();
return isSecure;
}

@Override
public X509Certificate[] getCertificates()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return (X509Certificate[])httpServletRequest.getAttribute("jakarta.servlet.request.X509Certificate");
}

@Override
public HttpServletRequest getHttpServletRequest()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return httpServletRequest;
}

@Override
public Locale getLocale()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return httpServletRequest.getLocale();
}

@Override
public Enumeration<Locale> getLocales()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return httpServletRequest.getLocales();
}

@Override
public SocketAddress getLocalSocketAddress()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return upgradeRequest.getConnectionMetaData().getLocalSocketAddress();
}

@Override
public SocketAddress getRemoteSocketAddress()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return upgradeRequest.getConnectionMetaData().getRemoteSocketAddress();
}

@@ -267,12 +294,18 @@ public String getRequestPath()
@Override
public Object getServletAttribute(String name)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return upgradeRequest.getAttribute(name);
}

@Override
public Map<String, Object> getServletAttributes()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

Map<String, Object> attributes = new HashMap<>(2);
Enumeration<String> attributeNames = httpServletRequest.getAttributeNames();
while (attributeNames.hasMoreElements())
@@ -286,18 +319,27 @@ public Map<String, Object> getServletAttributes()
@Override
public Map<String, List<String>> getServletParameters()
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return getParameterMap();
}

@Override
public boolean isUserInRole(String role)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

return httpServletRequest.isUserInRole(role);
}

@Override
public void setServletAttribute(String name, Object value)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

upgradeRequest.setAttribute(name, value);
}
}
Original file line number Diff line number Diff line change
@@ -24,6 +24,7 @@
import org.eclipse.jetty.ee10.websocket.server.JettyServerUpgradeResponse;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.HttpStatus;
import org.eclipse.jetty.http.HttpVersion;
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.websocket.api.ExtensionConfig;
import org.eclipse.jetty.websocket.common.JettyExtensionConfig;
@@ -32,36 +33,54 @@

public class DelegatedServerUpgradeResponse implements JettyServerUpgradeResponse
{
private final boolean upgraded;
private final ServerUpgradeResponse upgradeResponse;
private final HttpServletResponse httpServletResponse;
private final Map<String, List<String>> headers;
private final int status;

public DelegatedServerUpgradeResponse(ServerUpgradeResponse response)
{
upgradeResponse = response;
ServletContextResponse servletContextResponse = Response.as(response, ServletContextResponse.class);
this.httpServletResponse = (HttpServletResponse)servletContextResponse.getRequest()
this(response, false);
}

public DelegatedServerUpgradeResponse(ServerUpgradeResponse response, boolean upgraded)
{
this.upgraded = upgraded;
this.upgradeResponse = response;
this.httpServletResponse = (HttpServletResponse)Response.as(response, ServletContextResponse.class).getRequest()
.getAttribute(WebSocketConstants.WEBSOCKET_WRAPPED_RESPONSE_ATTRIBUTE);
this.headers = HttpFields.asMap(upgradeResponse.getHeaders());
this.headers = HttpFields.asMap(upgraded ? upgradeResponse.getHeaders().asImmutable() : upgradeResponse.getHeaders());

// Fake status code if already upgraded, as it not set at the time this is created.
HttpVersion httpVersion = response.getRequest().getConnectionMetaData().getHttpVersion();
this.status = (httpVersion == HttpVersion.HTTP_1_1) ? HttpStatus.SWITCHING_PROTOCOLS_101 : HttpStatus.OK_200;
}

@Override
public void addHeader(String name, String value)
{
// TODO: This should go to the httpServletResponse for headers but then it won't do interception of the websocket headers
// which are done through the jetty-core Response wrapping ServerUpgradeResponse done by websocket-core.
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

upgradeResponse.getHeaders().add(name, value);
}

@Override
public void setHeader(String name, String value)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

headers.put(name, List.of(value));
}

@Override
public void setHeader(String name, List<String> values)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

headers.put(name, values);
}

@@ -104,24 +123,36 @@ public List<String> getHeaders(String name)
@Override
public int getStatusCode()
{
return httpServletResponse.getStatus();
if (upgraded)
return status;
else
return httpServletResponse.getStatus();
}

@Override
public void sendForbidden(String message) throws IOException
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

httpServletResponse.sendError(HttpStatus.FORBIDDEN_403, message);
}

@Override
public void setAcceptedSubProtocol(String protocol)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

upgradeResponse.setAcceptedSubProtocol(protocol);
}

@Override
public void setExtensions(List<ExtensionConfig> configs)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

upgradeResponse.setExtensions(configs.stream()
.map(c -> new org.eclipse.jetty.websocket.core.ExtensionConfig(c.getName(), c.getParameters()))
.collect(Collectors.toList()));
@@ -130,18 +161,27 @@ public void setExtensions(List<ExtensionConfig> configs)
@Override
public void setStatusCode(int statusCode)
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

httpServletResponse.setStatus(statusCode);
}

@Override
public boolean isCommitted()
{
if (upgraded)
return true;

return httpServletResponse.isCommitted();
}

@Override
public void sendError(int statusCode, String message) throws IOException
{
if (upgraded)
throw new IllegalStateException("Already Upgraded to WebSocket");

httpServletResponse.sendError(statusCode, message);
}
}
Original file line number Diff line number Diff line change
@@ -40,8 +40,8 @@ public JettyServerFrameHandlerFactory(JettyWebSocketServerContainer container, W
public FrameHandler newFrameHandler(Object websocketPojo, ServerUpgradeRequest upgradeRequest, ServerUpgradeResponse upgradeResponse)
{
JettyWebSocketFrameHandler frameHandler = super.newJettyFrameHandler(websocketPojo);
frameHandler.setUpgradeRequest(new DelegatedServerUpgradeRequest(upgradeRequest));
frameHandler.setUpgradeResponse(new DelegatedServerUpgradeResponse(upgradeResponse));
frameHandler.setUpgradeRequest(new DelegatedServerUpgradeRequest(upgradeRequest, true));
frameHandler.setUpgradeResponse(new DelegatedServerUpgradeResponse(upgradeResponse, true));
return frameHandler;
}
}
Original file line number Diff line number Diff line change
@@ -36,49 +36,29 @@
import org.eclipse.jetty.server.Response;
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.util.Callback;
import org.eclipse.jetty.util.security.Credential;
import org.eclipse.jetty.websocket.api.Callback;
import org.eclipse.jetty.websocket.api.Session;
import org.eclipse.jetty.websocket.api.StatusCode;
import org.eclipse.jetty.websocket.api.UpgradeRequest;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketError;
import org.eclipse.jetty.websocket.api.annotations.OnWebSocketOpen;
import org.eclipse.jetty.websocket.api.UpgradeResponse;
import org.eclipse.jetty.websocket.api.annotations.WebSocket;
import org.eclipse.jetty.websocket.client.ClientUpgradeRequest;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import static org.eclipse.jetty.websocket.api.Callback.NOOP;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class ServerUpgradeRequestTest
{
private Server _server;
private ServerConnector _connector;
private WebSocketClient _client;

@WebSocket
public static class MyEndpoint
{
@OnWebSocketOpen
public void onOpen(Session session) throws Exception
{
UpgradeRequest upgradeRequest = session.getUpgradeRequest();
session.sendText("userPrincipal=" + upgradeRequest.getUserPrincipal(), NOOP);
session.sendText("requestURI=" + upgradeRequest.getRequestURI(), NOOP);
session.close();
}

@OnWebSocketError
public void onError(Throwable t)
{
t.printStackTrace();
}
}
private ServerConnector _connector;

private static class TestLoginService extends AbstractLoginService
{
@@ -128,7 +108,7 @@ public String getAuthenticationType()
}

@Override
public AuthenticationState validateRequest(Request request, Response response, Callback callback)
public AuthenticationState validateRequest(Request request, Response response, org.eclipse.jetty.util.Callback callback)
{
UserIdentity user = login("user123", null, request, response);
if (user != null)
@@ -140,19 +120,20 @@ public AuthenticationState validateRequest(Request request, Response response, C
}

@BeforeEach
public void before() throws Exception
public void start() throws Exception
{
_server = new Server();
_connector = new ServerConnector(_server);
_server.addConnector(_connector);

ServletContextHandler contextHandler = new ServletContextHandler();
contextHandler.setContextPath("/context1");
JettyWebSocketServletContainerInitializer.configure(contextHandler, ((servletContext, serverContainer) ->
{
serverContainer.addMapping("/ws", MyEndpoint.class);
}));
_server.setHandler(contextHandler);
ServletContextHandler servletContextHandler = new ServletContextHandler();
JettyWebSocketServletContainerInitializer.configure(servletContextHandler, (servletContext, container) ->
container.addMapping("/", (req, resp) ->
{
resp.setHeader("customHeader", "customHeaderValue");
resp.setAcceptedSubProtocol(req.getSubProtocols().get(0));
return new ServerSocket();
}));

DefaultIdentityService identityService = new DefaultIdentityService();
LoginService loginService = new TestLoginService(identityService);
@@ -165,34 +146,128 @@ public void before() throws Exception
securityHandler.setConstraintMappings(List.of(constraintMapping));
securityHandler.setLoginService(loginService);
securityHandler.setIdentityService(identityService);
contextHandler.setSecurityHandler(securityHandler);
servletContextHandler.setSecurityHandler(securityHandler);
securityHandler.setAuthenticator(new TestAuthenticator());

_server.setHandler(servletContextHandler);
_server.start();

_client = new WebSocketClient();
_client.start();
}

@AfterEach
public void after() throws Exception
public void stop() throws Exception
{
_client.stop();
_server.stop();
}

@WebSocket
public static class ServerSocket extends EventSocket
{
@Override
public void onMessage(String message)
{
StringBuilder builder = new StringBuilder();

try
{
switch (message)
{
case "getUpgradeRequest" ->
{
UpgradeRequest upgradeRequest = session.getUpgradeRequest();
builder.append("getRequestURI: ").append(upgradeRequest.getRequestURI()).append("\n");
builder.append("getHeaders: ").append(upgradeRequest.getHeaders()).append("\n");
builder.append("getExtensions: ").append(upgradeRequest.getExtensions()).append("\n");
builder.append("getHost: ").append(upgradeRequest.getHost()).append("\n");
builder.append("getHttpVersion: ").append(upgradeRequest.getHttpVersion()).append("\n");
builder.append("getQueryString: ").append(upgradeRequest.getQueryString()).append("\n");
builder.append("getSubProtocols: ").append(upgradeRequest.getSubProtocols()).append("\n");
builder.append("getProtocolVersion: ").append(upgradeRequest.getProtocolVersion()).append("\n");
builder.append("getCookies: ").append(upgradeRequest.getCookies()).append("\n");
builder.append("getUserPrincipal: ").append(upgradeRequest.getUserPrincipal()).append("\n");
builder.append("getOrigin: ").append(upgradeRequest.getOrigin()).append("\n");
builder.append("isSecure: ").append(upgradeRequest.isSecure()).append("\n");
builder.append("getParameterMap: ").append(upgradeRequest.getParameterMap()).append("\n");
}
case "getUpgradeResponse" ->
{
UpgradeResponse upgradeResponse = session.getUpgradeResponse();
builder.append("getHeaders: ").append(upgradeResponse.getHeaders()).append("\n");
builder.append("getExtensions: ").append(upgradeResponse.getExtensions()).append("\n");
builder.append("getStatusCode: ").append(upgradeResponse.getStatusCode()).append("\n");
builder.append("getAcceptedSubProtocol: ").append(upgradeResponse.getAcceptedSubProtocol()).append("\n");
}
default -> throw new IllegalStateException("Unknown message: " + message);
}
}
catch (Exception e)
{
e.printStackTrace(System.err);
throw e;
}

session.sendText(builder.toString(), Callback.NOOP);
}
}

@Test
public void test() throws Exception
public void testUpgradeRequest() throws Exception
{
URI uri = URI.create("ws://localhost:" + _connector.getLocalPort() + "/context1/ws");
URI uri = new URI("ws://localhost:" + _connector.getLocalPort() + "/?queryParam1=queryParamValue1");
ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setSubProtocols("subProtocol1", "subProtocol2");
upgradeRequest.addExtensions("permessage-deflate");
upgradeRequest.setHeader("Cookie", "cookieHeader1=cookieValue1");
upgradeRequest.setHeader("Origin", "jetty-test");
upgradeRequest.setHeader("CustomRequestHeader", "request-header-value");

EventSocket clientEndpoint = new EventSocket();
assertNotNull(_client.connect(clientEndpoint, uri));
Session session = _client.connect(clientEndpoint, uri, upgradeRequest).get(5, TimeUnit.SECONDS);

session.sendText("getUpgradeRequest", Callback.NOOP);
String received = clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(received, containsString("getRequestURI: " + uri));
assertThat(received, containsString("CustomRequestHeader=[request-header-value]"));
assertThat(received, containsString("getExtensions: [permessage-deflate]"));
assertThat(received, containsString("getHost: localhost"));
assertThat(received, containsString("getHttpVersion: HTTP/1.1"));
assertThat(received, containsString("getQueryString: queryParam1=queryParamValue1"));
assertThat(received, containsString("getSubProtocols: [subProtocol1, subProtocol2]"));
assertThat(received, containsString("getProtocolVersion: 13"));
assertThat(received, containsString("getCookies: [cookieHeader1=\"cookieValue1\"]"));
assertThat(received, containsString("getUserPrincipal: user123"));
assertThat(received, containsString("getOrigin: jetty-test"));
assertThat(received, containsString("isSecure: false"));
assertThat(received, containsString("getParameterMap: {queryParam1=[queryParamValue1]}"));

String msg = clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(msg, equalTo("userPrincipal=user123"));
session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeCode, equalTo(StatusCode.NORMAL));
}

@Test
public void testUpgradeResponse() throws Exception
{
URI uri = new URI("ws://localhost:" + _connector.getLocalPort());

ClientUpgradeRequest upgradeRequest = new ClientUpgradeRequest();
upgradeRequest.setSubProtocols("subProtocol1", "subProtocol2");
upgradeRequest.addExtensions("permessage-deflate");

EventSocket clientEndpoint = new EventSocket();
Session session = _client.connect(clientEndpoint, uri, upgradeRequest).get(5, TimeUnit.SECONDS);

msg = clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(msg, equalTo("requestURI=ws://localhost:" + _connector.getLocalPort() + "/context1/ws"));
session.sendText("getUpgradeResponse", Callback.NOOP);
String received = clientEndpoint.textMessages.poll(5, TimeUnit.SECONDS);
assertThat(received, containsString("customHeader=[customHeaderValue]"));
assertThat(received, containsString("getExtensions: [permessage-deflate]"));
assertThat(received, containsString("getStatusCode: 101"));
assertThat(received, containsString("getAcceptedSubProtocol: subProtocol1"));

session.close();
assertTrue(clientEndpoint.closeLatch.await(5, TimeUnit.SECONDS));
assertThat(clientEndpoint.closeCode, equalTo(StatusCode.NORMAL));
}