15
15
from typing import TYPE_CHECKING , Any , Callable , ClassVar
16
16
17
17
from dask .typing import Key
18
- from dask .utils import funcname , tmpfile
18
+ from dask .utils import _deprecated_kwarg , funcname , tmpfile
19
19
20
20
from distributed .protocol .pickle import dumps
21
21
@@ -896,36 +896,46 @@ async def setup(self, nanny):
896
896
nanny .env .update (self .environ )
897
897
898
898
899
- class UploadDirectory (NannyPlugin ):
900
- """A NannyPlugin to upload a local file to workers.
899
+ UPLOAD_DIRECTORY_MODES = ["all" , "scheduler" , "workers" ]
900
+
901
+
902
+ class UploadDirectory (SchedulerPlugin ):
903
+ """Scheduler plugin to upload a local directory to the cluster.
901
904
902
905
Parameters
903
906
----------
904
- path: str
905
- A path to the directory to upload
907
+ path:
908
+ Path to the directory to upload
909
+ mode:
910
+ Which nodes receive the directory: ``"all"``, ``"scheduler"`` or ``"workers"`` (the default)
906
911
907
912
Examples
908
913
--------
909
914
>>> from distributed.diagnostics.plugin import UploadDirectory
910
- >>> client.register_plugin(UploadDirectory("/path/to/directory"), nanny=True ) # doctest: +SKIP
915
+ >>> client.register_plugin(UploadDirectory("/path/to/directory")) # doctest: +SKIP
911
916
"""
912
917
918
+ @_deprecated_kwarg ("restart" , "restart_workers" )
913
919
def __init__ (
914
920
self ,
915
921
path ,
916
- restart = False ,
922
+ restart_workers = False ,
917
923
update_path = False ,
918
924
skip_words = (".git" , ".github" , ".pytest_cache" , "tests" , "docs" ),
919
925
skip = (lambda fn : os .path .splitext (fn )[1 ] == ".pyc" ,),
926
+ mode = "workers" ,
920
927
):
921
- """
922
- Initialize the plugin by reading in the data from the given file.
923
- """
924
928
path = os .path .expanduser (path )
925
929
self .path = os .path .split (path )[- 1 ]
926
- self .restart = restart
930
+ self .restart_workers = restart_workers
927
931
self .update_path = update_path
928
932
933
+ if mode not in UPLOAD_DIRECTORY_MODES :
934
+ raise ValueError (
935
+ f"{ mode = } not supported, expected one of { UPLOAD_DIRECTORY_MODES } "
936
+ )
937
+ self .mode = mode
938
+
929
939
self .name = "upload-directory-" + os .path .split (path )[- 1 ]
930
940
931
941
with tmpfile (extension = "zip" ) as fn :
@@ -944,26 +954,67 @@ def __init__(
944
954
)
945
955
z .write (filename , archive_name )
946
956
947
- with open (fn , "rb" ) as f :
957
+ with open (fn , mode = "rb" ) as f :
948
958
self .data = f .read ()
949
959
950
- async def setup (self , nanny ):
951
- fn = os .path .join (nanny .local_directory , f"tmp-{ uuid .uuid4 ()} .zip" )
952
- with open (fn , "wb" ) as f :
953
- f .write (self .data )
960
async def start(self, scheduler):
    """Distribute the zipped directory according to ``self.mode``.

    For modes ``"all"`` and ``"scheduler"`` the archive is extracted into
    ``scheduler.local_directory``.  For modes ``"all"`` and ``"workers"`` a
    ``_UploadDirectoryNannyPlugin`` carrying the same archive is registered
    through ``scheduler.register_nanny_plugin``.

    Parameters
    ----------
    scheduler:
        The scheduler this plugin is started on.

    Raises
    ------
    Exception
        The first error reported in the nanny-plugin registration responses
        is deserialized and re-raised with its original traceback.
    """
    from distributed.core import clean_exception
    from distributed.protocol.serialize import Serialized, deserialize

    if self.mode in ("all", "scheduler"):
        # Forward the ``update_path`` option stored by ``__init__`` -- the same
        # flag the nanny plugin below receives -- so that ``sys.path`` on the
        # scheduler is only touched when the user asked for it.
        _extract_data(
            scheduler.local_directory, self.path, self.data, self.update_path
        )

    if self.mode in ("all", "workers"):
        nanny_plugin = _UploadDirectoryNannyPlugin(
            self.path, self.data, self.restart_workers, self.update_path, self.name
        )
        responses = await scheduler.register_nanny_plugin(
            comm=None,
            plugin=dumps(nanny_plugin),
            name=self.name,
            idempotent=False,
        )

        for response in responses.values():
            if response["status"] == "error":
                # Only the serialized entries are kept and decoded; they are
                # handed to ``clean_exception`` as keyword arguments.
                response = {
                    k: deserialize(v.header, v.frames)
                    for k, v in response.items()
                    if isinstance(v, Serialized)
                }
                _, exc, tb = clean_exception(**response)
                raise exc.with_traceback(tb)
989
+
990
+
991
class _UploadDirectoryNannyPlugin(NannyPlugin):
    """Nanny-side helper that unpacks an uploaded directory archive.

    Parameters
    ----------
    path:
        Name of the directory inside the archive; joined onto the nanny's
        ``local_directory`` when ``update_path`` is set.
    data:
        Bytes of the zip archive to extract.
    restart:
        Stored as ``self.restart``.
        NOTE(review): presumably read by the nanny plugin machinery to decide
        whether to restart the worker -- confirm against ``NannyPlugin``.
    update_path:
        Whether the extracted directory is inserted into ``sys.path``.
    name:
        Name under which the plugin is registered.
    """

    def __init__(self, path, data, restart, update_path, name):
        self.name = name
        self.restart = restart
        self.update_path = update_path
        self.path, self.data = path, data

    def setup(self, nanny):
        """Extract the archive into ``nanny.local_directory``."""
        _extract_data(
            base_path=nanny.local_directory,
            path=self.path,
            data=self.data,
            update_path=self.update_path,
        )
1001
+
1002
+
1003
+ def _extract_data (base_path , path , data , update_path ):
1004
+ with tmpfile (extension = "zip" ) as fn :
1005
+ with open (fn , mode = "wb" ) as f :
1006
+ f .write (data )
954
1007
955
1008
import zipfile
956
1009
957
1010
with zipfile .ZipFile (fn ) as z :
958
- z .extractall (path = nanny . local_directory )
1011
+ z .extractall (path = base_path )
959
1012
960
- if self . update_path :
961
- path = os .path .join (nanny . local_directory , self . path )
1013
+ if update_path :
1014
+ path = os .path .join (base_path , path )
962
1015
if path not in sys .path :
963
1016
sys .path .insert (0 , path )
964
1017
965
- os .remove (fn )
966
-
967
1018
968
1019
class forward_stream :
969
1020
def __init__ (self , stream , worker ):
0 commit comments