-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path syft_rpc_client.py
executable file
· 230 lines (195 loc) · 8.8 KB
/
syft_rpc_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
from __future__ import annotations
import threading
import time
import argparse
import sys
from datetime import datetime, timezone
from typing import List, Optional, Dict, Any, Callable, Type
from loguru import logger
from pydantic import BaseModel, Field
from syft_event import SyftEvents
from syft_event.types import Request
from syft_core import Client
from syft_rpc import rpc
class SyftRPCClient:
    """A generic Syft RPC client that can be extended for various applications.

    This template demonstrates:
    1. Background server to handle incoming requests
    2. Client methods to send requests to other datasites
    3. Discovery of available datasites
    4. Error handling and resource management

    Extend this class to create your own custom RPC applications. When no
    ``request_model`` / ``response_model`` is supplied, the nested
    ``PingRequest`` / ``PongResponse`` models (fields ``msg`` and ``ts``) are
    used. These match the payloads built by the default ``send_request`` and
    ``_handle_request``.
    """

    class PingRequest(BaseModel):
        """Default request body used when no ``request_model`` is given."""

        msg: str
        ts: datetime

    class PongResponse(BaseModel):
        """Default response body used when no ``response_model`` is given."""

        msg: str
        ts: datetime

    def __init__(
        self,
        config_path: Optional[str] = None,
        app_name: str = "pingpong",
        endpoint: str = "/ping",
        request_model: Optional[Type[BaseModel]] = None,
        response_model: Optional[Type[BaseModel]] = None,
        start_server: bool = True,
    ):
        """Initialize the Syft RPC client.

        Args:
            config_path: Path to the Syft config file, passed to ``Client.load``.
            app_name: RPC app name; used to create the server and build request URLs.
            endpoint: Endpoint path that is served and targeted (e.g. ``"/ping"``).
            request_model: Pydantic model for request bodies. Defaults to
                ``PingRequest``.
            response_model: Pydantic model for response bodies. Defaults to
                ``PongResponse``.
            start_server: If true, start the RPC server in a background thread.
        """
        self.client = Client.load(config_path)
        self.app_name = app_name
        self.endpoint = endpoint
        # Without a fallback, the default handler and send_request would call
        # None(...) and every request would fail with TypeError.
        self.request_model = request_model if request_model is not None else self.PingRequest
        self.response_model = response_model if response_model is not None else self.PongResponse
        self.stop_event = threading.Event()
        self.server_thread: Optional[threading.Thread] = None
        logger.info(f"π Connected as: {self.client.email}")
        if start_server:
            self._start_server()

    def _start_server(self):
        """Start the RPC server in a daemon background thread.

        If a server thread started by this client is still alive, only log a
        warning and return, so repeated calls don't spawn duplicate servers.
        """
        if self.server_thread is not None and self.server_thread.is_alive():
            logger.warning(f"Server already running for {self.client.email}")
            return
        self.stop_event.clear()
        self.server_thread = threading.Thread(target=self._run_server, daemon=True)
        self.server_thread.start()
        logger.info(f"π Server started for {self.client.email}")

    def _run_server(self):
        """Run the RPC server loop until ``stop_event`` is set.

        Steps:
        1. Create the SyftEvents server.
        2. Register a handler for ``self.endpoint`` and start the server.
        3. Poll for pending requests roughly every 0.1 s.

        Errors raised after creation are logged, and the server is always
        stopped on exit.
        """
        box = self._create_server()
        logger.info(f"π SERVER: Running {self.app_name} server as {self.client.email}")
        # Capture the model once so the closure doesn't depend on later changes to self.
        request_model = self.request_model

        @box.on_request(self.endpoint)
        def request_handler(request_data: dict, ctx: Request) -> dict:
            request_data = self._coerce_request(request_data, request_model)
            response = self._handle_request(request_data, ctx, box)
            # Return JSON-compatible data so datetimes etc. serialize correctly.
            return self._serialize_response(response)

        try:
            logger.info(f"π‘ SERVER: Listening for requests at {box.app_rpc_dir}")
            try:
                box.start()
            except RuntimeError as e:
                # A watch on this path is presumably already registered; the
                # original author treats that as non-fatal.
                if "already scheduled" not in str(e):
                    raise
                logger.warning(f"Watch already exists: {e}. Continuing anyway.")
            while not self.stop_event.is_set():
                box.process_pending_requests()
                # Event.wait acts as the poll delay, but it returns immediately
                # once close() sets the event, so shutdown isn't delayed.
                self.stop_event.wait(0.1)
        except Exception as e:
            logger.exception(f"β SERVER ERROR: {e}")
        finally:
            try:
                box.stop()
            except Exception as e:
                logger.error(f"Error stopping server: {e}")

    @staticmethod
    def _coerce_request(request_data: Any, request_model: Optional[Type[BaseModel]]) -> Any:
        """Best-effort conversion of an incoming request body into ``request_model``.

        Args:
            request_data: The body as delivered by SyftEvents.
            request_model: Target model class, or ``None`` to skip conversion.

        Returns:
            A ``request_model`` instance if conversion succeeds. Otherwise the
            original ``request_data``; on failure the error is logged so the
            handler can still decide what to do.
        """
        if request_model is None or isinstance(request_data, request_model):
            return request_data
        try:
            if isinstance(request_data, dict):
                return request_model(**request_data)
            return request_model.model_validate(request_data)
        except Exception as e:
            logger.error(f"Failed to convert request to {request_model.__name__}: {e}")
            return request_data

    @staticmethod
    def _serialize_response(response: Any) -> Any:
        """Convert a handler response into JSON-compatible data.

        - Pydantic v2 models are dumped with ``mode="json"``, which renders
          datetimes as ISO-8601 strings.
        - Pydantic v1 models have no ``model_dump``, and their ``.dict()``
          accepts no ``json_encoders`` argument. They are therefore
          round-tripped through ``.json()``, which applies pydantic's JSON
          encoders.
        - In both model cases, ``None`` fields are omitted.
        - Any other value is returned unchanged.
        """
        if hasattr(response, "model_dump"):
            return response.model_dump(exclude_none=True, mode="json")
        if hasattr(response, "dict") and hasattr(response, "json"):
            import json

            return json.loads(response.json(exclude_none=True))
        return response

    def run_standalone_server(self):
        """Run the server in the current thread (for standalone mode).

        Blocks until interrupted (Ctrl+C) or until ``close()`` sets the stop
        event. Always calls ``close()`` on the way out.
        """
        try:
            logger.info(f"π Starting standalone {self.app_name} server as {self.client.email}")
            self._run_server()
        except KeyboardInterrupt:
            logger.info("Server interrupted by user. Shutting down...")
        finally:
            self.close()

    def _create_server(self):
        """Create and return the SyftEvents server for this app."""
        return SyftEvents(self.app_name, client=self.client)

    def _handle_request(self, request_data: BaseModel, ctx: Request, box) -> BaseModel:
        """Handle one incoming request. Override this in your subclass.

        The default implementation logs the request and replies with
        ``response_model(msg=..., ts=...)``, so it assumes the response model
        has ``msg`` and ``ts`` fields.
        """
        logger.info(f"π RECEIVED: Request - {request_data}")
        return self.response_model(
            msg=f"Response from {box.client.email}",
            ts=datetime.now(timezone.utc),
        )

    def send_request(
        self,
        to_email: str,
        request_data: Optional[BaseModel] = None,
        timeout: float = 60,
        expiry: str = "5m",
    ) -> Optional[BaseModel]:
        """Send a request to the specified datasite and wait for the reply.

        Args:
            to_email: The email/datasite to send to.
            request_data: Optional custom request body. If ``None``, a
                ``request_model(msg=..., ts=...)`` is built.
            timeout: Seconds to wait for the response (passed to ``future.wait``).
            expiry: Request expiry passed to ``rpc.send`` (e.g. ``"5m"``).

        Returns:
            The response parsed as ``response_model`` if successful. ``None``
            if the datasite is unknown or sending/waiting/parsing fails.
        """
        if not self._valid_datasite(to_email):
            logger.error(f"Invalid datasite: {to_email}")
            logger.info("Available datasites:")
            for d in self.list_datasites():
                logger.info(f" - {d}")
            return None

        if request_data is None:
            request_data = self.request_model(
                msg=f"Hello from {self.client.email}!",
                ts=datetime.now(timezone.utc),
            )

        logger.info(f"π€ SENDING: Request to {to_email}")
        start = time.monotonic()
        try:
            # rpc.send sits inside the try so send failures also yield None,
            # as documented, instead of escaping to the caller.
            future = rpc.send(
                url=f"syft://{to_email}/api_data/{self.app_name}/rpc{self.endpoint}",
                body=request_data,
                expiry=expiry,
                cache=True,
                client=self.client,
            )
            response = future.wait(timeout=timeout)
            response.raise_for_status()
            model_response = response.model(self.response_model)
        except Exception as e:
            logger.error(f"β CLIENT ERROR: {e}")
            return None
        elapsed = time.monotonic() - start
        logger.info(f"π₯ RECEIVED: Response from {to_email}. Time: {elapsed:.2f}s")
        return model_response

    def list_datasites(self) -> List[str]:
        """Return the sorted names of entries in ``client.datasites`` that contain "@"."""
        return sorted(ds.name for ds in self.client.datasites.glob("*") if "@" in ds.name)

    def list_available_servers(self) -> List[str]:
        """Return the datasites that publish this app's RPC schema.

        A datasite counts as available if
        ``<datasites>/<ds>/api_data/<app_name>/rpc/rpc.schema.json`` exists.
        """
        return [
            ds
            for ds in self.list_datasites()
            if (self.client.datasites / ds / "api_data" / self.app_name / "rpc" / "rpc.schema.json").exists()
        ]

    def _valid_datasite(self, ds: str) -> bool:
        """Return True if ``ds`` is one of ``list_datasites()``."""
        return ds in self.list_datasites()

    def close(self):
        """Shut down the client.

        Sets the stop event and waits up to 2 s for the background server
        thread to exit. Joining is skipped when called from the server thread
        itself, because a thread cannot join itself.
        """
        logger.info(f"π Shutting down {self.app_name} client...")
        self.stop_event.set()
        thread = self.server_thread
        if thread is not None and thread is not threading.current_thread():
            thread.join(timeout=2)
            if thread.is_alive():
                logger.warning("Server thread did not stop within 2s")
# Command-line entry point: run this module directly to serve standalone.
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Run a standalone SyftRPC server")
    cli.add_argument("--config", help="Path to config.json file")
    options = cli.parse_args()

    # A minimal client using the default app settings. start_server=False
    # because run_standalone_server() runs the server loop in this thread.
    # Subclasses should provide their own __main__ block with specific models.
    server = SyftRPCClient(config_path=options.config, start_server=False)
    server.run_standalone_server()