diff --git a/distributed/comm/tcp.py b/distributed/comm/tcp.py
index 774d606e2a1..c741aba83e9 100644
--- a/distributed/comm/tcp.py
+++ b/distributed/comm/tcp.py
@@ -44,14 +44,14 @@
 
 logger = logging.getLogger(__name__)
 
-# Workaround for OpenSSL 1.0.2.
-# Can drop with OpenSSL 1.1.1 used by Python 3.10+.
-# ref: https://bugs.python.org/issue42853
-if sys.version_info < (3, 10):
-    OPENSSL_MAX_CHUNKSIZE = 256 ** ctypes.sizeof(ctypes.c_int) // 2 - 1
-else:
-    OPENSSL_MAX_CHUNKSIZE = 256 ** ctypes.sizeof(ctypes.c_size_t) - 1
-
+# We must not load more than this into a buffer at a time
+# It's currently unclear why that is
+# see
+# - https://github.com/dask/distributed/pull/5854
+# - https://bugs.python.org/issue42853
+# - https://github.com/dask/distributed/pull/8507
+
+C_INT_MAX = 256 ** ctypes.sizeof(ctypes.c_int) // 2 - 1
 
 MAX_BUFFER_SIZE = MEMORY_LIMIT / 2
 
@@ -286,8 +286,8 @@ async def write(self, msg, serializers=None, on_error="message"):
                             2,
                             range(
                                 0,
-                                each_frame_nbytes + OPENSSL_MAX_CHUNKSIZE,
-                                OPENSSL_MAX_CHUNKSIZE,
+                                each_frame_nbytes + C_INT_MAX,
+                                C_INT_MAX,
                             ),
                         ):
                             chunk = each_frame[i:j]
@@ -360,7 +360,7 @@ async def read_bytes_rw(stream: IOStream, n: int) -> memoryview:
 
     for i, j in sliding_window(
         2,
-        range(0, n + OPENSSL_MAX_CHUNKSIZE, OPENSSL_MAX_CHUNKSIZE),
+        range(0, n + C_INT_MAX, C_INT_MAX),
     ):
         chunk = buf[i:j]
         actual = await stream.read_into(chunk)  # type: ignore[arg-type]
@@ -432,7 +432,8 @@ class TLS(TCP):
     A TLS-specific version of TCP.
     """
 
-    max_shard_size = min(OPENSSL_MAX_CHUNKSIZE, TCP.max_shard_size)
+    # Workaround for OpenSSL 1.0.2 (can drop with OpenSSL 1.1.1)
+    max_shard_size = min(C_INT_MAX, TCP.max_shard_size)
 
     def _read_extra(self):
         TCP._read_extra(self)
diff --git a/distributed/protocol/tests/test_protocol.py b/distributed/protocol/tests/test_protocol.py
index 630239c1adc..c9db26fb79c 100644
--- a/distributed/protocol/tests/test_protocol.py
+++ b/distributed/protocol/tests/test_protocol.py
@@ -208,3 +208,30 @@ def test_fallback_to_pickle():
     assert L[0].count(b"__Pickled__") == 1
     assert L[0].count(b"__Serialized__") == 1
     assert loads(L) == {np.int64(1): {2: "a"}, 3: ("b", "c"), 4: "d"}
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("typ", [bytes, str, "ext"])
+def test_large_payload(typ):
+    """See also: test_core.py::test_large_payload"""
+    critical_size = 2**31 + 1  # >2 GiB
+    if typ == bytes:
+        large_payload = critical_size * b"0"
+        expected = large_payload
+    elif typ == str:
+        large_payload = critical_size * "0"
+        expected = large_payload
+    # Testing array and map dtypes is practically not possible since we'd have
+    # to create an actual list or dict object of critical size (i.e. not the
+    # content but the container itself). These are so large that msgpack would
+    # run forever
+    # elif typ == "array":
+    #     large_payload = [b"0"] * critical_size
+    #     expected = tuple(large_payload)
+    # elif typ == "map":
+    #     large_payload = {x: b"0" for x in range(critical_size)}
+    #     expected = large_payload
+    elif typ == "ext":
+        large_payload = msgpack.ExtType(1, b"0" * critical_size)
+        expected = large_payload
+    assert loads(dumps(large_payload)) == expected
diff --git a/distributed/protocol/utils.py b/distributed/protocol/utils.py
index f7fa7a9984a..0777b8f9e0d 100644
--- a/distributed/protocol/utils.py
+++ b/distributed/protocol/utils.py
@@ -12,9 +12,7 @@
 
 BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard"))
 
-msgpack_opts = {
-    ("max_%s_len" % x): 2**31 - 1 for x in ["str", "bin", "array", "map", "ext"]
-}
+msgpack_opts = {}
 msgpack_opts["strict_map_key"] = False
 msgpack_opts["raw"] = False
 
diff --git a/distributed/tests/test_core.py b/distributed/tests/test_core.py
index 85a502bafee..93bb18f16d3 100644
--- a/distributed/tests/test_core.py
+++ b/distributed/tests/test_core.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import contextlib
+import logging
 import os
 import random
 import socket
@@ -1481,3 +1482,27 @@ def sync_handler(val):
         assert ledger == list(range(n))
     finally:
         await comm.close()
+
+
+@pytest.mark.slow
+@gen_test(timeout=180)
+async def test_large_payload(caplog):
+    """See also: protocol/tests/test_protocol.py::test_large_payload"""
+    critical_size = 2**31 + 1  # >2 GiB
+    data = b"0" * critical_size
+
+    async with Server({"echo": echo_serialize}) as server:
+        await server.listen(0)
+        comm = await connect(server.address)
+
+        # FIXME https://github.com/dask/distributed/issues/8465
+        # At debug level, messages are dumped into the log. By default, pytest captures
+        # all logs, which would make this test extremely expensive to run.
+        with caplog.at_level(logging.INFO, logger="distributed.core"):
+            # Note: if we wrap data in to_serialize, it will be sent as a buffer, which
+            # is not encoded by msgpack.
+            await comm.write({"op": "echo", "x": data})
+            response = await comm.read()
+
+        assert response["result"] == data
+        await comm.close()