Skip to content

Commit 22f037f

Browse files
authored
Handle matrix subclass serialization (#8480)
1 parent a170014 commit 22f037f

File tree

3 files changed

+11
-1
lines changed

3 files changed

+11
-1
lines changed

distributed/protocol/numpy.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ def itemsize(dt):
2626
return result
2727

2828

29+
@dask_serialize.register(np.matrix)
30+
def serialize_numpy_matrix(x, context=None):
31+
header, frames = serialize_numpy_ndarray(x)
32+
header["matrix"] = True
33+
return header, frames
34+
35+
2936
@dask_serialize.register(np.ndarray)
3037
def serialize_numpy_ndarray(x, context=None):
3138
if x.dtype.hasobject or (x.dtype.flags & np_core.multiarray.LIST_PICKLE):
@@ -151,7 +158,8 @@ def deserialize_numpy_ndarray(header, frames):
151158
# buffers the decompressed output is deep-copied beforehand into a
152159
# bytearray in order to merge it.
153160
x = np.require(x, requirements=["W"])
154-
161+
if header.get("matrix"):
162+
x = np.asmatrix(x)
155163
return x
156164

157165

distributed/protocol/tests/test_numpy.py

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def test_serialize():
7777
np.arange(12)[::2], # non-contiguous array
7878
np.ones(shape=(5, 6)).astype(dtype=[("total", "<f8"), ("n", "<f8")]),
7979
np.broadcast_to(np.arange(3), shape=(10, 3)), # zero-strided array
80+
np.matrix([[1, 2], [3, 4]]),
8081
],
8182
)
8283
def test_dumps_serialize_numpy(x):

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ filterwarnings = [
151151
'''ignore:datetime\.datetime\.utc(fromtimestamp|now)\(\) is deprecated and scheduled for removal in a future version.*:DeprecationWarning:dateutil''',
152152
# https://github.com/dask/dask/pull/10622
153153
'''ignore:Minimal version of pyarrow will soon be increased to 14.0.1''',
154+
'''ignore:the matrix subclass is not the recommended way''',
154155
]
155156
minversion = "6"
156157
markers = [

0 commit comments

Comments
 (0)