Skip to content

Commit 106cc9a

Browse files
committed
Add support for websockets (#4578)
1 parent 8e08cbf commit 106cc9a

File tree

4 files changed

+349
-8
lines changed

4 files changed

+349
-8
lines changed

scapy/contrib/websocket.py

+259
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
# SPDX-License-Identifier: GPL-2.0-or-later
2+
# This file is part of Scapy
3+
# See https://scapy.net/ for more information
4+
# Copyright (C) 2024 Lucas Drufva <[email protected]>
5+
6+
# scapy.contrib.description = WebSocket
7+
# scapy.contrib.status = loads
8+
9+
# Based on rfc6455
10+
11+
import struct
12+
import base64
13+
import zlib
14+
from hashlib import sha1
15+
from scapy.fields import (BitFieldLenField, Field, BitField, BitEnumField, ConditionalField, XNBytesField)
16+
from scapy.layers.http import HTTPRequest, HTTPResponse
17+
from scapy.layers.inet import TCP
18+
from scapy.packet import Packet
19+
from scapy.error import Scapy_Exception
20+
import logging
21+
22+
23+
class PayloadLenField(BitFieldLenField):
24+
25+
def __init__(self, name, default, length_of, size=0, tot_size=0, end_tot_size=0):
26+
# Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField)
27+
super().__init__(name, default, size, length_of=length_of, tot_size=tot_size, end_tot_size=end_tot_size)
28+
29+
def getfield(self, pkt, s):
30+
s, _ = s
31+
# Get the 7-bit field (first byte)
32+
length_byte = s[0] & 0x7F
33+
s = s[1:]
34+
35+
if length_byte <= 125:
36+
# 7-bit length
37+
return s, length_byte
38+
elif length_byte == 126:
39+
# 16-bit length
40+
length = struct.unpack("!H", s[:2])[0] # Read 2 bytes
41+
s = s[2:]
42+
return s, length
43+
elif length_byte == 127:
44+
# 64-bit length
45+
length = struct.unpack("!Q", s[:8])[0] # Read 8 bytes
46+
s = s[8:]
47+
return s, length
48+
49+
def addfield(self, pkt, s, val):
50+
p_field, p_val = pkt.getfield_and_val(self.length_of)
51+
val = p_field.i2len(pkt, p_val)
52+
53+
if val <= 125:
54+
self.size = 7
55+
return super().addfield(pkt, s, val)
56+
elif val <= 0xFFFF:
57+
self.size = 7+16
58+
s, _, masked = s
59+
return s + struct.pack("!BH", 126 | masked, val)
60+
elif val <= 0xFFFFFFFFFFFFFFFF:
61+
self.size = 7+64
62+
s, _, masked = s
63+
return s + struct.pack("!BQ", 127 | masked, val)
64+
else:
65+
raise Scapy_Exception("%s: Payload length too large" %
66+
self.__class__.__name__)
67+
68+
69+
70+
class PayloadField(Field):
71+
"""
72+
Field for handling raw byte payloads with dynamic size.
73+
The length of the payload is described by a preceding PayloadLenField.
74+
"""
75+
__slots__ = ["lengthFrom"]
76+
77+
def __init__(self, name, lengthFrom):
78+
"""
79+
:param name: Field name
80+
:param lengthFrom: Field name that provides the length of the payload
81+
"""
82+
super(PayloadField, self).__init__(name, None)
83+
self.lengthFrom = lengthFrom
84+
85+
def getfield(self, pkt, s):
86+
# Fetch the length from the field that specifies the length
87+
length = getattr(pkt, self.lengthFrom)
88+
payloadData = s[:length]
89+
90+
if pkt.mask:
91+
key = struct.pack("I", pkt.maskingKey)[::-1]
92+
data_int = int.from_bytes(payloadData, 'big')
93+
mask_repeated = key * (len(payloadData) // 4) + key[: len(payloadData) % 4]
94+
mask_int = int.from_bytes(mask_repeated, 'big')
95+
payloadData = (data_int ^ mask_int).to_bytes(len(payloadData), 'big')
96+
97+
if("permessage-deflate" in pkt.extensions):
98+
try:
99+
payloadData = pkt.decoder[0](payloadData + b"\x00\x00\xff\xff")
100+
except Exception:
101+
logging.debug("Failed to decompress payload", payloadData)
102+
103+
return s[length:], payloadData
104+
105+
def addfield(self, pkt, s, val):
106+
if pkt.mask:
107+
key = struct.pack("I", pkt.maskingKey)[::-1]
108+
data_int = int.from_bytes(val, 'big')
109+
mask_repeated = key * (len(val) // 4) + key[: len(val) % 4]
110+
mask_int = int.from_bytes(mask_repeated, 'big')
111+
val = (data_int ^ mask_int).to_bytes(len(val), 'big')
112+
113+
return s + bytes(val)
114+
115+
def i2len(self, pkt, val):
116+
# Length of the payload in bytes
117+
return len(val)
118+
119+
class WebSocket(Packet):
120+
__slots__ = ["extensions", "decoder"]
121+
122+
name = "WebSocket"
123+
fields_desc = [
124+
BitField("fin", 0, 1),
125+
BitField("rsv", 0, 3),
126+
BitEnumField("opcode", 0, 4,
127+
{
128+
0x0: "none",
129+
0x1: "text",
130+
0x2: "binary",
131+
0x8: "close",
132+
0x9: "ping",
133+
0xA: "pong",
134+
}),
135+
BitField("mask", 0, 1),
136+
PayloadLenField("payloadLen", 0, length_of="wsPayload", size=1),
137+
ConditionalField(XNBytesField("maskingKey", 0, sz=4), lambda pkt: pkt.mask == 1),
138+
PayloadField("wsPayload", lengthFrom="payloadLen")
139+
]
140+
141+
def __init__(self, pkt=None, extensions=[], decoder=None, *args, **fields):
142+
self.extensions = extensions
143+
self.decoder = decoder
144+
super().__init__(_pkt=pkt, *args, **fields)
145+
146+
def extract_padding(self, s):
147+
return '', s
148+
149+
@classmethod
150+
def tcp_reassemble(cls, data, metadata, session):
151+
# data = the reassembled data from the same request/flow
152+
# metadata = empty dictionary, that can be used to store data
153+
# during TCP reassembly
154+
# session = a dictionary proper to the bidirectional TCP session,
155+
# that can be used to store anything
156+
# [...]
157+
# If the packet is available, return it. Otherwise don't.
158+
# Whenever you return a packet, the buffer will be discarded.
159+
160+
161+
HANDSHAKE_STATE_CLIENT_OPEN = 0
162+
HANDSHAKE_STATE_SERVER_OPEN = 1
163+
HANDSHAKE_STATE_OPEN = 2
164+
165+
if "handshake-state" not in session:
166+
session["handshake-state"] = HANDSHAKE_STATE_CLIENT_OPEN
167+
168+
if "extensions" not in session:
169+
session["extensions"] = {}
170+
171+
172+
if session["handshake-state"] == HANDSHAKE_STATE_CLIENT_OPEN:
173+
ht = HTTPRequest(data)
174+
175+
if ht.Method != b"GET":
176+
return None
177+
178+
if not ht.Upgrade or ht.Upgrade.lower() != b"websocket":
179+
return None
180+
181+
if b"Sec-WebSocket-Key" not in ht.Unknown_Headers:
182+
return None
183+
184+
185+
session["handshake-key"] = ht.Unknown_Headers[b"Sec-WebSocket-Key"]
186+
187+
if "original" in metadata:
188+
session["server-port"] = metadata["original"][TCP].dport
189+
else:
190+
print("No original packet")
191+
192+
session["handshake-state"] = HANDSHAKE_STATE_SERVER_OPEN
193+
194+
return ht
195+
196+
elif session["handshake-state"] == HANDSHAKE_STATE_SERVER_OPEN:
197+
ht = HTTPResponse(data)
198+
199+
if not ht.Upgrade.lower() == b"websocket":
200+
return None
201+
202+
# Verify key-accept handshake:
203+
correct_accept = base64.b64encode(sha1(session["handshake-key"] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode()).digest())
204+
if ht.Unknown_Headers[b"Sec-WebSocket-Accept"] != correct_accept:
205+
#TODO handle or Logg wrong accept key
206+
pass
207+
208+
if b"Sec-WebSocket-Extensions" in ht.Unknown_Headers:
209+
session["extensions"] = {}
210+
for extension in ht.Unknown_Headers[b"Sec-WebSocket-Extensions"].decode().strip().split(";"):
211+
key_value_pair = extension.split("=", 1) + [None]
212+
session["extensions"][key_value_pair[0].strip()] = key_value_pair[1]
213+
214+
if "permessage-deflate" in session["extensions"]:
215+
def create_decompressor(window_bits):
216+
decoder = zlib.decompressobj(wbits=-window_bits)
217+
def decomp(data):
218+
nonlocal decoder
219+
return decoder.decompress(data, 0)
220+
221+
def reset():
222+
nonlocal decoder
223+
nonlocal window_bits
224+
decoder = zlib.decompressobj(wbits=-window_bits)
225+
226+
return (decomp, reset)
227+
228+
# Default values
229+
client_wb = 12
230+
server_wb = 15
231+
232+
# Check for new values in extensions header
233+
if "client_max_window_bits" in session["extensions"]:
234+
client_wb = int(session["extensions"]["client_max_window_bits"])
235+
236+
if "server_max_window_bits" in session["extensions"]:
237+
server_wb = int(session["extensions"]["server_max_window_bits"])
238+
239+
240+
session["server-decoder"] = create_decompressor(client_wb)
241+
session["client-decoder"] = create_decompressor(server_wb)
242+
243+
244+
session["handshake-state"] = HANDSHAKE_STATE_OPEN
245+
246+
return ht
247+
248+
249+
# Handshake is done:
250+
if "original" not in metadata:
251+
return
252+
253+
if "permessage-deflate" in session["extensions"]:
254+
is_server = True if metadata["original"][TCP].sport == session["server-port"] else False
255+
ws = WebSocket(bytes(data), extensions=session["extensions"], decoder = session["server-decoder"] if is_server else session["client-decoder"])
256+
return ws
257+
else:
258+
ws = WebSocket(bytes(data), extensions=session["extensions"])
259+
return ws

scapy/layers/http.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -529,10 +529,10 @@ def do_dissect(self, s):
529529
"""From the HTTP packet string, populate the scapy object"""
530530
first_line, body = _dissect_headers(self, s)
531531
try:
532-
Method, Path, HTTPVersion = re.split(br"\s+", first_line, maxsplit=2)
533-
self.setfieldval('Method', Method)
534-
self.setfieldval('Path', Path)
535-
self.setfieldval('Http_Version', HTTPVersion)
532+
method_path_version = re.split(br"\s+", first_line, maxsplit=2) + [None]
533+
self.setfieldval('Method', method_path_version[0])
534+
self.setfieldval('Path', method_path_version[1])
535+
self.setfieldval('Http_Version', method_path_version[2])
536536
except ValueError:
537537
pass
538538
if body:
@@ -573,10 +573,10 @@ def do_dissect(self, s):
573573
''' From the HTTP packet string, populate the scapy object '''
574574
first_line, body = _dissect_headers(self, s)
575575
try:
576-
HTTPVersion, Status, Reason = re.split(br"\s+", first_line, maxsplit=2)
577-
self.setfieldval('Http_Version', HTTPVersion)
578-
self.setfieldval('Status_Code', Status)
579-
self.setfieldval('Reason_Phrase', Reason)
576+
version_status_reason = re.split(br"\s+", first_line, maxsplit=2) + [None]
577+
self.setfieldval('Http_Version', version_status_reason[0])
578+
self.setfieldval('Status_Code', version_status_reason[1])
579+
self.setfieldval('Reason_Phrase', version_status_reason[2])
580580
except ValueError:
581581
pass
582582
if body:

test/contrib/websocket.uts

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# WebSocket layer unit tests
2+
# Copyright (C) 2024 Lucas Drufva <[email protected]>
3+
#
4+
# Type the following command to launch start the tests:
5+
# $ test/run_tests -P "load_contrib('websocket')" -t test/contrib/websocket.uts
6+
7+
+ Syntax check
8+
= Import the WebSocket layer
9+
from scapy.contrib.websocket import *
10+
11+
+ WebSocket protocol test
12+
= Packet instantiation
13+
pkt = WebSocket(wsPayload=b"Hello, world!", opcode="text", mask=True, maskingKey=0x11223344)
14+
15+
assert pkt.wsPayload == b"Hello, world!"
16+
assert pkt.mask == True
17+
assert pkt.maskingKey == 0x11223344
18+
19+
20+
= Packet dissection
21+
raw = b'\x01\x0dHello, world!'
22+
pkt = WebSocket(raw)
23+
24+
assert pkt.fin == 0
25+
assert pkt.rsv == 0
26+
assert pkt.opcode == 0x1
27+
assert pkt.mask == False
28+
assert pkt.payloadLen == 13
29+
assert pkt.wsPayload == b'Hello, world!'
30+
31+
= Dissect masked packet
32+
raw = b'\x01\x8d\x11\x22\x33\x44\x59\x47\x5f\x28\x7e\x0e\x13\x33\x7e\x50\x5f\x20\x30'
33+
pkt = WebSocket(raw)
34+
35+
assert pkt.fin == 0
36+
assert pkt.rsv == 0
37+
assert pkt.opcode == 0x1
38+
assert pkt.mask == True
39+
assert pkt.payloadLen == 13
40+
assert pkt.wsPayload == b'Hello, world!'
41+
42+
= Session with compression
43+
44+
bind_layers(TCP, WebSocket, dport=5000)
45+
bind_layers(TCP, WebSocket, sport=5000)
46+
47+
from scapy.sessions import TCPSession
48+
49+
filename = scapy_path("/test/pcaps/websocket_compressed_session.pcap")
50+
pkts = sniff(offline=filename, session=TCPSession)
51+
52+
assert len(pkts) == 13
53+
54+
assert pkts[7][WebSocket].wsPayload == b'Hello'
55+
assert pkts[8][WebSocket].wsPayload == b'"Hello"'
56+
assert pkts[10][WebSocket].wsPayload == b'Hello2'
57+
assert pkts[11][WebSocket].wsPayload == b'"Hello2"'
58+
59+
= Create packet with long payload
60+
pkt = WebSocket(wsPayload=b"a"*126, opcode="text")
61+
62+
assert bytes(pkt) == b'\x01\x7e\x00\x7e\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61'
63+
64+
= Dissect packet with long payload
65+
raw = b'\x01\x7e\x00\x7e\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61\x61'
66+
pkt = WebSocket(raw)
67+
68+
assert pkt.payloadLen == 126
69+
assert pkt.wsPayload == b'a'*126
70+
71+
= Create packet with very long payload
72+
pkt = WebSocket(wsPayload=b"a"*65536, opcode="text")
73+
74+
assert bytes(pkt) == b'\x01\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + b'a'*65536
75+
76+
= Dissect packet with very long payload
77+
raw = b'\x01\x7f\x00\x00\x00\x00\x00\x01\x00\x00' + b'a'*65536
78+
pkt = WebSocket(raw)
79+
80+
assert pkt.payloadLen == 65536
81+
assert pkt.wsPayload == b'a'*65536
82+
1.51 KB
Binary file not shown.

0 commit comments

Comments
 (0)