1
+ # SPDX-License-Identifier: GPL-2.0-or-later
2
+ # This file is part of Scapy
3
+ # See https://scapy.net/ for more information
4
+ # Copyright (C) 2024 Lucas Drufva <[email protected] >
5
+
6
+ # scapy.contrib.description = WebSocket
7
+ # scapy.contrib.status = loads
8
+
9
+ # Based on rfc6455
10
+
11
+ import struct
12
+ import base64
13
+ import zlib
14
+ from hashlib import sha1
15
+ from scapy .fields import (BitFieldLenField , Field , BitField , BitEnumField , ConditionalField , XNBytesField )
16
+ from scapy .layers .http import HTTPRequest , HTTPResponse
17
+ from scapy .layers .inet import TCP
18
+ from scapy .packet import Packet
19
+ from scapy .error import Scapy_Exception
20
+ import logging
21
+
22
+
23
+ class PayloadLenField (BitFieldLenField ):
24
+
25
+ def __init__ (self , name , default , length_of , size = 0 , tot_size = 0 , end_tot_size = 0 ):
26
+ # Initialize with length_of (like in BitFieldLenField) and lengthFrom (like in BitLenField)
27
+ super ().__init__ (name , default , size , length_of = length_of , tot_size = tot_size , end_tot_size = end_tot_size )
28
+
29
+ def getfield (self , pkt , s ):
30
+ s , _ = s
31
+ # Get the 7-bit field (first byte)
32
+ length_byte = s [0 ] & 0x7F
33
+ s = s [1 :]
34
+
35
+ if length_byte <= 125 :
36
+ # 7-bit length
37
+ return s , length_byte
38
+ elif length_byte == 126 :
39
+ # 16-bit length
40
+ length = struct .unpack ("!H" , s [:2 ])[0 ] # Read 2 bytes
41
+ s = s [2 :]
42
+ return s , length
43
+ elif length_byte == 127 :
44
+ # 64-bit length
45
+ length = struct .unpack ("!Q" , s [:8 ])[0 ] # Read 8 bytes
46
+ s = s [8 :]
47
+ return s , length
48
+
49
+ def addfield (self , pkt , s , val ):
50
+ p_field , p_val = pkt .getfield_and_val (self .length_of )
51
+ val = p_field .i2len (pkt , p_val )
52
+
53
+ if val <= 125 :
54
+ self .size = 7
55
+ return super ().addfield (pkt , s , val )
56
+ elif val <= 0xFFFF :
57
+ self .size = 7 + 16
58
+ s , _ , masked = s
59
+ return s + struct .pack ("!BH" , 126 | masked , val )
60
+ elif val <= 0xFFFFFFFFFFFFFFFF :
61
+ self .size = 7 + 64
62
+ s , _ , masked = s
63
+ return s + struct .pack ("!BQ" , 127 | masked , val )
64
+ else :
65
+ raise Scapy_Exception ("%s: Payload length too large" %
66
+ self .__class__ .__name__ )
67
+
68
+
69
+
70
+ class PayloadField (Field ):
71
+ """
72
+ Field for handling raw byte payloads with dynamic size.
73
+ The length of the payload is described by a preceding PayloadLenField.
74
+ """
75
+ __slots__ = ["lengthFrom" ]
76
+
77
+ def __init__ (self , name , lengthFrom ):
78
+ """
79
+ :param name: Field name
80
+ :param lengthFrom: Field name that provides the length of the payload
81
+ """
82
+ super (PayloadField , self ).__init__ (name , None )
83
+ self .lengthFrom = lengthFrom
84
+
85
+ def getfield (self , pkt , s ):
86
+ # Fetch the length from the field that specifies the length
87
+ length = getattr (pkt , self .lengthFrom )
88
+ payloadData = s [:length ]
89
+
90
+ if pkt .mask :
91
+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
92
+ data_int = int .from_bytes (payloadData , 'big' )
93
+ mask_repeated = key * (len (payloadData ) // 4 ) + key [: len (payloadData ) % 4 ]
94
+ mask_int = int .from_bytes (mask_repeated , 'big' )
95
+ payloadData = (data_int ^ mask_int ).to_bytes (len (payloadData ), 'big' )
96
+
97
+ if ("permessage-deflate" in pkt .extensions ):
98
+ try :
99
+ payloadData = pkt .decoder [0 ](payloadData + b"\x00 \x00 \xff \xff " )
100
+ except Exception :
101
+ logging .debug ("Failed to decompress payload" , payloadData )
102
+
103
+ return s [length :], payloadData
104
+
105
+ def addfield (self , pkt , s , val ):
106
+ if pkt .mask :
107
+ key = struct .pack ("I" , pkt .maskingKey )[::- 1 ]
108
+ data_int = int .from_bytes (val , 'big' )
109
+ mask_repeated = key * (len (val ) // 4 ) + key [: len (val ) % 4 ]
110
+ mask_int = int .from_bytes (mask_repeated , 'big' )
111
+ val = (data_int ^ mask_int ).to_bytes (len (val ), 'big' )
112
+
113
+ return s + bytes (val )
114
+
115
+ def i2len (self , pkt , val ):
116
+ # Length of the payload in bytes
117
+ return len (val )
118
+
119
+ class WebSocket (Packet ):
120
+ __slots__ = ["extensions" , "decoder" ]
121
+
122
+ name = "WebSocket"
123
+ fields_desc = [
124
+ BitField ("fin" , 0 , 1 ),
125
+ BitField ("rsv" , 0 , 3 ),
126
+ BitEnumField ("opcode" , 0 , 4 ,
127
+ {
128
+ 0x0 : "none" ,
129
+ 0x1 : "text" ,
130
+ 0x2 : "binary" ,
131
+ 0x8 : "close" ,
132
+ 0x9 : "ping" ,
133
+ 0xA : "pong" ,
134
+ }),
135
+ BitField ("mask" , 0 , 1 ),
136
+ PayloadLenField ("payloadLen" , 0 , length_of = "wsPayload" , size = 1 ),
137
+ ConditionalField (XNBytesField ("maskingKey" , 0 , sz = 4 ), lambda pkt : pkt .mask == 1 ),
138
+ PayloadField ("wsPayload" , lengthFrom = "payloadLen" )
139
+ ]
140
+
141
+ def __init__ (self , pkt = None , extensions = [], decoder = None , * args , ** fields ):
142
+ self .extensions = extensions
143
+ self .decoder = decoder
144
+ super ().__init__ (_pkt = pkt , * args , ** fields )
145
+
146
+ def extract_padding (self , s ):
147
+ return '' , s
148
+
149
+ @classmethod
150
+ def tcp_reassemble (cls , data , metadata , session ):
151
+ # data = the reassembled data from the same request/flow
152
+ # metadata = empty dictionary, that can be used to store data
153
+ # during TCP reassembly
154
+ # session = a dictionary proper to the bidirectional TCP session,
155
+ # that can be used to store anything
156
+ # [...]
157
+ # If the packet is available, return it. Otherwise don't.
158
+ # Whenever you return a packet, the buffer will be discarded.
159
+
160
+
161
+ HANDSHAKE_STATE_CLIENT_OPEN = 0
162
+ HANDSHAKE_STATE_SERVER_OPEN = 1
163
+ HANDSHAKE_STATE_OPEN = 2
164
+
165
+ if "handshake-state" not in session :
166
+ session ["handshake-state" ] = HANDSHAKE_STATE_CLIENT_OPEN
167
+
168
+ if "extensions" not in session :
169
+ session ["extensions" ] = {}
170
+
171
+
172
+ if session ["handshake-state" ] == HANDSHAKE_STATE_CLIENT_OPEN :
173
+ ht = HTTPRequest (data )
174
+
175
+ if ht .Method != b"GET" :
176
+ return None
177
+
178
+ if not ht .Upgrade or ht .Upgrade .lower () != b"websocket" :
179
+ return None
180
+
181
+ if b"Sec-WebSocket-Key" not in ht .Unknown_Headers :
182
+ return None
183
+
184
+
185
+ session ["handshake-key" ] = ht .Unknown_Headers [b"Sec-WebSocket-Key" ]
186
+
187
+ if "original" in metadata :
188
+ session ["server-port" ] = metadata ["original" ][TCP ].dport
189
+ else :
190
+ print ("No original packet" )
191
+
192
+ session ["handshake-state" ] = HANDSHAKE_STATE_SERVER_OPEN
193
+
194
+ return ht
195
+
196
+ elif session ["handshake-state" ] == HANDSHAKE_STATE_SERVER_OPEN :
197
+ ht = HTTPResponse (data )
198
+
199
+ if not ht .Upgrade .lower () == b"websocket" :
200
+ return None
201
+
202
+ # Verify key-accept handshake:
203
+ correct_accept = base64 .b64encode (sha1 (session ["handshake-key" ] + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" .encode ()).digest ())
204
+ if ht .Unknown_Headers [b"Sec-WebSocket-Accept" ] != correct_accept :
205
+ #TODO handle or Logg wrong accept key
206
+ pass
207
+
208
+ if b"Sec-WebSocket-Extensions" in ht .Unknown_Headers :
209
+ session ["extensions" ] = {}
210
+ for extension in ht .Unknown_Headers [b"Sec-WebSocket-Extensions" ].decode ().strip ().split (";" ):
211
+ key_value_pair = extension .split ("=" , 1 ) + [None ]
212
+ session ["extensions" ][key_value_pair [0 ].strip ()] = key_value_pair [1 ]
213
+
214
+ if "permessage-deflate" in session ["extensions" ]:
215
+ def create_decompressor (window_bits ):
216
+ decoder = zlib .decompressobj (wbits = - window_bits )
217
+ def decomp (data ):
218
+ nonlocal decoder
219
+ return decoder .decompress (data , 0 )
220
+
221
+ def reset ():
222
+ nonlocal decoder
223
+ nonlocal window_bits
224
+ decoder = zlib .decompressobj (wbits = - window_bits )
225
+
226
+ return (decomp , reset )
227
+
228
+ # Default values
229
+ client_wb = 12
230
+ server_wb = 15
231
+
232
+ # Check for new values in extensions header
233
+ if "client_max_window_bits" in session ["extensions" ]:
234
+ client_wb = int (session ["extensions" ]["client_max_window_bits" ])
235
+
236
+ if "server_max_window_bits" in session ["extensions" ]:
237
+ server_wb = int (session ["extensions" ]["server_max_window_bits" ])
238
+
239
+
240
+ session ["server-decoder" ] = create_decompressor (client_wb )
241
+ session ["client-decoder" ] = create_decompressor (server_wb )
242
+
243
+
244
+ session ["handshake-state" ] = HANDSHAKE_STATE_OPEN
245
+
246
+ return ht
247
+
248
+
249
+ # Handshake is done:
250
+ if "original" not in metadata :
251
+ return
252
+
253
+ if "permessage-deflate" in session ["extensions" ]:
254
+ is_server = True if metadata ["original" ][TCP ].sport == session ["server-port" ] else False
255
+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ], decoder = session ["server-decoder" ] if is_server else session ["client-decoder" ])
256
+ return ws
257
+ else :
258
+ ws = WebSocket (bytes (data ), extensions = session ["extensions" ])
259
+ return ws
0 commit comments