[POC] Distributed checkpointing API for dask collections #8483

Draft · wants to merge 3 commits into base: main
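
In brief: the new `DataFrameCheckpoint` API writes each partition of a Dask DataFrame to worker-local storage and can later rebuild a lazy collection from those files. A minimal usage sketch inferred from the diff below (the path and column name are illustrative, and an active distributed client is assumed):

    import pandas as pd
    import dask.dataframe as dd
    from distributed import Client
    from distributed.checkpoint import DataFrameCheckpoint

    client = Client()
    ddf = dd.from_pandas(pd.DataFrame({"x": range(100)}), npartitions=4)

    # Write every partition to local storage on the worker that holds it
    ckpt = DataFrameCheckpoint.create(ddf, "/tmp/ckpt")

    # Rebuild a lazy collection from the on-disk partitions and verify
    dd.assert_eq(ddf, ckpt.load())

    # Remove the on-disk files and invalidate the checkpoint object
    ckpt.clean()
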
247 changes: 247 additions & 0 deletions distributed/checkpoint.py
@@ -0,0 +1,247 @@
from __future__ import annotations

import contextlib
import glob
import os
import pickle
from collections import defaultdict
from importlib import import_module

import dask.dataframe as dd
from dask.blockwise import BlockIndex
from dask.utils import typename

from distributed import default_client, wait
from distributed.protocol import dask_deserialize, dask_serialize


class Handler:
    """Base class for format-specific checkpointing handlers

    A ``Handler`` object will be responsible for a single partition.
    """

    fmt: None | str = None  # General format label

    def __init__(self, path, backend, index, **kwargs):
        self.path = path
        self.backend = backend
        self.index = index
        self.kwargs = kwargs

    @classmethod
    def clean(cls, dirpath):
        """Clean the target directory"""
        import shutil

        if os.path.isdir(dirpath):
            with contextlib.suppress(FileNotFoundError):
                shutil.rmtree(dirpath)

    @classmethod
    def prepare(cls, dirpath):
        """Create the target directory"""
        os.makedirs(dirpath, exist_ok=True)

    @classmethod
    def save(cls, part, path, index):
        """Persist the target partition to disk"""
        raise NotImplementedError()  # Logic depends on format

    @classmethod
    def get_indices(cls, path):
        """Return set of local indices"""
        # Assume file-name is something like: <name>.<index>.<fmt>
        return {int(fn.split(".")[-2]) for fn in glob.glob(path + f"/*.{cls.fmt}")}

    def load(self):
        """Collect the saved partition"""
        raise NotImplementedError()  # Logic depends on format


@dask_serialize.register(Handler)
def _serialize_unloaded(obj):
    # Make sure we read the partition into memory if
    # this partition is moved to a different worker
    return None, [pickle.dumps(obj.load())]


@dask_deserialize.register(Handler)
def _deserialize_unloaded(header, frames):
    # Deserializing a `Handler` object returns the wrapped data
    return pickle.loads(frames[0])


class ParquetHandler(Handler):
    """Parquet-specific checkpointing handler for DataFrame collections"""

    fmt = "parquet"

    @classmethod
    def save(cls, part, path, index):
        fn = f"{path}/part.{index[0]}.parquet"
        part.to_parquet(fn)
        return index

    def load(self):
        lib = import_module(self.backend)
        fn = glob.glob(f"{self.path}/*.{self.index}.parquet")
        return lib.read_parquet(fn, **self.kwargs)


class BaseCheckpoint:
    """Checkpoint a Dask collection on disk

    The storage location does not need to be shared between workers.
    """

    @classmethod
    def create(cls, *args, **kwargs):
        """Create a new Checkpoint object"""
        raise NotImplementedError()

    def load(self):
        """Load a checkpointed collection

        Note that this will not immediately persist the partitions
        in memory. Rather, it will output a lazy Dask collection.
        """
        raise NotImplementedError()

    def clean(self):
        """Clean up this checkpoint"""
        raise NotImplementedError()


class DataFrameCheckpoint(BaseCheckpoint):
    """Checkpoint a Dask DataFrame on disk"""

    def __init__(
        self,
        npartitions,
        meta,
        handler,
        path,
        load_kwargs,
    ):
        self.npartitions = npartitions
        self.meta = meta
        self.backend = typename(meta).partition(".")[0]
        self.handler = handler
        self.path = path
        self.load_kwargs = load_kwargs or {}
        self._valid = True

    def __repr__(self):
        path = self.path
        fmt = self.handler.fmt
        return f"DataFrameCheckpoint<path={path}, format={fmt}>"

    @classmethod
    def create(
        cls,
        df,
        path,
        format="parquet",
        overwrite=True,
        compute_kwargs=None,
        load_kwargs=None,
        **save_kwargs,
    ):
        # Get handler
        if format == "parquet":
            handler = ParquetHandler
        else:
            # Only parquet supported for now
            raise NotImplementedError()

        client = default_client()

        if overwrite:
            wait(client.run(handler.clean, path))
        wait(client.run(handler.prepare, path))

        meta = df._meta.copy()
        df.map_partitions(
            handler.save,
            path,
            BlockIndex((df.npartitions,)),
            meta=meta,
            enforce_metadata=False,
            **save_kwargs,
        ).compute(**(compute_kwargs or {}))

        return cls(
            df.npartitions,
            meta,
            handler,
            path,
            load_kwargs,
        )

    def load(self):
        if not self._valid:
            raise RuntimeError("This checkpoint is no longer valid")

        #
        # Get client and check workers
        #
        client = default_client()

        #
        # Find out which partition indices are stored on each worker
        #
        worker_indices = client.run(self.handler.get_indices, self.path)
        summary = defaultdict(list)
        for worker, indices in worker_indices.items():
            for index in indices:
                summary[index].append(worker)

        # Check partition count
        npartitions_found = len(summary)
        if npartitions_found != self.npartitions:
            raise RuntimeError(
                f"Expected {self.npartitions} partitions. "
                f"Found {npartitions_found}."
            )

        #
        # Convert each checkpointed partition to a `Handler` object,
        # pinned to a worker that holds that partition on disk
        #
        assignments = {}
        futures = []
        for i, (index, workers) in enumerate(summary.items()):
            # `summary` maps partition index -> workers holding that
            # partition; round-robin over replicas when there are several
            assignments[index] = workers[i % len(workers)]
            futures.append(
                client.submit(
                    self.handler,
                    self.path,
                    self.backend,
                    index,
                    workers=[assignments[index]],
                    **self.load_kwargs,
                )
            )
        wait(futures)

        #
        # Create a new collection from the delayed `Handler` objects
        #
        meta = self.meta
        return dd.from_delayed(futures, meta=meta, verify_meta=False).map_partitions(
            self._load_partition,
            meta=meta,
        )

    @staticmethod
    def _load_partition(obj):
        # Load a checkpointed partition.
        # Used by DataFrameCheckpoint.load
        if isinstance(obj, Handler):
            return obj.load()
        return obj

    def clean(self):
        client = default_client()
        wait(client.run(self.handler.clean, self.path))
        self._valid = False
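
Note for reviewers: `Handler` is the intended extension point for formats beyond parquet. A hypothetical pickle-based handler, written against the classmethod signatures above (illustration only, not part of this PR), could look like:

    class PickleHandler(Handler):
        # Hypothetical example -- not included in this diff
        fmt = "pickle"

        @classmethod
        def save(cls, part, path, index):
            # Mirror ParquetHandler.save: one file per partition, named
            # <name>.<index>.<fmt> so that Handler.get_indices() finds it
            fn = f"{path}/part.{index[0]}.pickle"
            with open(fn, "wb") as f:
                pickle.dump(part, f)
            return index

        def load(self):
            [fn] = glob.glob(f"{self.path}/*.{self.index}.pickle")
            with open(fn, "rb") as f:
                return pickle.load(f)
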
14 changes: 14 additions & 0 deletions distributed/tests/test_dask_collections.py
@@ -252,3 +252,17 @@ def test_tuple_futures_arg(client, typ):
        ),
    )
    dd.assert_eq(df2.result().iloc[:0], make_time_dataframe().iloc[:0])


def test_dataframe_checkpoint(client, tmp_path):
    from distributed.checkpoint import DataFrameCheckpoint

    df = make_time_dataframe()
    ddf = dd.from_pandas(df, npartitions=10)

    ckpt = DataFrameCheckpoint.create(ddf, str(tmp_path))
    client.cancel(ddf)
    del ddf

    # Must use distributed scheduler to compute
    dd.assert_eq(df, ckpt.load(), scheduler=None)
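
To run just the new test locally (standard pytest node-id syntax; assumes a development install of distributed):

    pytest distributed/tests/test_dask_collections.py::test_dataframe_checkpoint -v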