Skip to content

Commit 69f7982

Browse files
authored
[DPE-6296] Pyright fixes + structured_config additions + break down of actions.py (#13)
1 parent 14d4daa commit 69f7982

21 files changed

+756
-591
lines changed

poetry.lock

+104-103
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/benchmark/base_charm.py

+19-148
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,26 @@
1616

1717
import logging
1818
import subprocess
19-
from abc import ABC, abstractmethod
2019
from typing import Any
2120

22-
import ops
21+
from charms.data_platform_libs.v0.data_models import TypedCharmBase
2322
from charms.grafana_agent.v0.cos_agent import COSAgentProvider
2423
from ops.charm import CharmEvents
2524
from ops.framework import EventBase, EventSource
2625
from ops.model import BlockedStatus
2726

2827
from benchmark.core.models import DPBenchmarkLifecycleState
2928
from benchmark.core.pebble_workload_base import DPBenchmarkPebbleWorkloadBase
29+
from benchmark.core.structured_config import BenchmarkCharmConfig
3030
from benchmark.core.systemd_workload_base import DPBenchmarkSystemdWorkloadBase
3131
from benchmark.core.workload_base import WorkloadBase
32+
from benchmark.events.actions import ActionsHandler
3233
from benchmark.events.db import DatabaseRelationHandler
3334
from benchmark.events.peer import PeerRelationHandler
3435
from benchmark.literals import (
3536
COS_AGENT_RELATION,
3637
METRICS_PORT,
3738
PEER_RELATION,
38-
DPBenchmarkLifecycleTransition,
3939
DPBenchmarkMissingOptionsError,
4040
)
4141
from benchmark.managers.config import ConfigManager
@@ -70,34 +70,22 @@ def workload_build(workload_params_template: str) -> WorkloadBase:
7070
return DPBenchmarkSystemdWorkloadBase(workload_params_template)
7171

7272

73-
class DPBenchmarkCharmBase(ops.CharmBase, ABC):
73+
class DPBenchmarkCharmBase(TypedCharmBase[BenchmarkCharmConfig]):
7474
"""The base benchmark class."""
7575

76-
on = DPBenchmarkEvents() # pyright: ignore [reportGeneralTypeIssues]
76+
on = DPBenchmarkEvents() # pyright: ignore [reportAssignmentType]
7777

7878
RESOURCE_DEB_NAME = "benchmark-deb"
7979
workload_params_template = ""
8080

81+
config_type = BenchmarkCharmConfig
82+
8183
def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None = None):
8284
super().__init__(*args)
8385
self.framework.observe(self.on.install, self._on_install)
8486
self.framework.observe(self.on.config_changed, self._on_config_changed)
8587
self.framework.observe(self.on.update_status, self._on_update_status)
8688

87-
self.framework.observe(self.on.prepare_action, self.on_prepare_action)
88-
self.framework.observe(self.on.run_action, self.on_run_action)
89-
self.framework.observe(self.on.stop_action, self.on_stop_action)
90-
self.framework.observe(self.on.cleanup_action, self.on_clean_action)
91-
92-
self.framework.observe(
93-
self.on.check_upload,
94-
self._on_check_upload,
95-
)
96-
self.framework.observe(
97-
self.on.check_collect,
98-
self._on_check_collect,
99-
)
100-
10189
self.database = DatabaseRelationHandler(self, db_relation_name)
10290
self.peers = PeerRelationHandler(self, PEER_RELATION)
10391
self.framework.observe(self.database.on.db_config_update, self._on_config_changed)
@@ -119,8 +107,8 @@ def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None =
119107

120108
self.config_manager = ConfigManager(
121109
workload=self.workload,
122-
database=self.database.state,
123-
peer=self.peers.peers(),
110+
database_state=self.database.state,
111+
peers=self.peers.peers(),
124112
config=self.config,
125113
labels=self.labels,
126114
)
@@ -129,11 +117,7 @@ def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None =
129117
self.peers.this_unit(),
130118
self.config_manager,
131119
)
132-
133-
@abstractmethod
134-
def supported_workloads(self) -> list[str]:
135-
"""List of supported workloads."""
136-
...
120+
self.actions = ActionsHandler(self)
137121

138122
###########################################################################
139123
#
@@ -146,28 +130,6 @@ def _on_install(self, event: EventBase) -> None:
146130
self.workload.install()
147131
self.peers.state.lifecycle = DPBenchmarkLifecycleState.UNSET
148132

149-
def _on_check_collect(self, event: EventBase) -> None:
150-
"""Check if the upload is finished."""
151-
if self.config_manager.is_collecting():
152-
# Nothing to do, upload is still in progress
153-
event.defer()
154-
return
155-
156-
if self.unit.is_leader():
157-
self.peers.state.set(DPBenchmarkLifecycleState.UPLOADING)
158-
# Raise we are running an upload and we will check the status later
159-
self.on.check_upload.emit()
160-
return
161-
self.peers.state.set(DPBenchmarkLifecycleState.FINISHED)
162-
163-
def _on_check_upload(self, event: EventBase) -> None:
164-
"""Check if the upload is finished."""
165-
if self.config_manager.is_uploading():
166-
# Nothing to do, upload is still in progress
167-
event.defer()
168-
return
169-
self.peers.state.lifecycle = DPBenchmarkLifecycleState.FINISHED
170-
171133
def _on_update_status(self, event: EventBase | None = None) -> None:
172134
"""Set status for the operator and finishes the service.
173135
@@ -176,34 +138,20 @@ def _on_update_status(self, event: EventBase | None = None) -> None:
176138
benchmark service and the benchmark status.
177139
"""
178140
try:
179-
status = self.database.state.get()
141+
status = self.database.state.model()
180142
except DPBenchmarkMissingOptionsError as e:
181143
self.unit.status = BlockedStatus(str(e))
182144
return
183145
if not status:
184146
self.unit.status = BlockedStatus("No database relation available")
185147
return
186148

187-
# We need to narrow the options of workload_name to the supported ones
188-
if self.config.get("workload_name") not in self.supported_workloads():
189-
self.unit.status = BlockedStatus(
190-
f"Unsupported workload: {self.config.get('workload_name')}"
191-
)
192-
return
193-
194149
# Now, let's check if we need to update our lifecycle position
195-
self._update_state()
150+
self.update_state()
196151
self.unit.status = self.lifecycle.status
197152

198153
def _on_config_changed(self, event: EventBase) -> None:
199154
"""Config changed event."""
200-
# We need to narrow the options of workload_name to the supported ones
201-
if self.config.get("workload_name") not in self.supported_workloads():
202-
self.unit.status = BlockedStatus(
203-
f"Unsupported workload: {self.config.get('workload_name')}"
204-
)
205-
return
206-
207155
if not self.config_manager.is_prepared():
208156
# nothing to do: set the status and leave
209157
self._on_update_status()
@@ -228,88 +176,6 @@ def scrape_config(self) -> list[dict[str, Any]]:
228176
}
229177
]
230178

231-
###########################################################################
232-
#
233-
# Action and Lifecycle Handlers
234-
#
235-
###########################################################################
236-
237-
def _preflight_checks(self) -> bool:
238-
"""Check if we have the necessary relations."""
239-
if len(self.peers.units()) > 0 and not bool(self.peers.state.get()):
240-
return False
241-
try:
242-
return bool(self.database.state.get())
243-
except DPBenchmarkMissingOptionsError:
244-
return False
245-
246-
def on_prepare_action(self, event: EventBase) -> None:
247-
"""Process the prepare action."""
248-
if not self._preflight_checks():
249-
event.fail("Missing DB or S3 relations")
250-
return
251-
252-
if not (state := self.lifecycle.next(DPBenchmarkLifecycleTransition.PREPARE)):
253-
event.fail("Failed to prepare the benchmark: already done")
254-
return
255-
256-
if state != DPBenchmarkLifecycleState.PREPARING:
257-
event.fail(
258-
"Another peer is already in prepare state. Wait or call clean action to reset."
259-
)
260-
return
261-
262-
# We process the special case of PREPARE, as explained in lifecycle.make_transition()
263-
if not self.config_manager.prepare():
264-
event.fail("Failed to prepare the benchmark")
265-
return
266-
267-
self.lifecycle.make_transition(state)
268-
self.unit.status = self.lifecycle.status
269-
event.set_results({"message": "Benchmark is being prepared"})
270-
271-
def on_run_action(self, event: EventBase) -> None:
272-
"""Process the run action."""
273-
if not self._preflight_checks():
274-
event.fail("Missing DB or S3 relations")
275-
return
276-
277-
if not self._process_action_transition(DPBenchmarkLifecycleTransition.RUN):
278-
event.fail("Failed to run the benchmark")
279-
event.set_results({"message": "Benchmark has started"})
280-
281-
def on_stop_action(self, event: EventBase) -> None:
282-
"""Process the stop action."""
283-
if not self._preflight_checks():
284-
event.fail("Missing DB or S3 relations")
285-
return
286-
287-
if not self._process_action_transition(DPBenchmarkLifecycleTransition.STOP):
288-
event.fail("Failed to stop the benchmark")
289-
event.set_results({"message": "Benchmark has stopped"})
290-
291-
def on_clean_action(self, event: EventBase) -> None:
292-
"""Process the clean action."""
293-
if not self._preflight_checks():
294-
event.fail("Missing DB or S3 relations")
295-
return
296-
297-
if not self._process_action_transition(DPBenchmarkLifecycleTransition.CLEAN):
298-
event.fail("Failed to clean the benchmark")
299-
event.set_results({"message": "Benchmark is cleaning"})
300-
301-
def _process_action_transition(self, transition: DPBenchmarkLifecycleTransition) -> bool:
302-
"""Process the action."""
303-
# First, check if we have an update in our lifecycle state
304-
self._update_state()
305-
306-
if not (state := self.lifecycle.next(transition)):
307-
return False
308-
309-
self.lifecycle.make_transition(state)
310-
self.unit.status = self.lifecycle.status
311-
return True
312-
313179
###########################################################################
314180
#
315181
# Helpers
@@ -318,9 +184,14 @@ def _process_action_transition(self, transition: DPBenchmarkLifecycleTransition)
318184

319185
def _unit_ip(self) -> str:
320186
"""Current unit ip."""
321-
return self.model.get_binding(PEER_RELATION).network.bind_address
187+
bind_address = None
188+
if PEER_RELATION:
189+
if binding := self.model.get_binding(PEER_RELATION):
190+
bind_address = binding.network.bind_address
191+
192+
return str(bind_address) if bind_address else ""
322193

323-
def _update_state(self) -> None:
194+
def update_state(self) -> None:
324195
"""Update the state of the charm."""
325196
if (next_state := self.lifecycle.next(None)) and self.lifecycle.current() != next_state:
326197
self.lifecycle.make_transition(next_state)

src/benchmark/core/models.py

+9-23
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
"""
1010

1111
import logging
12-
from typing import Any, Optional
12+
from typing import Any, MutableMapping, Optional
1313

1414
from ops.model import Application, Relation, Unit
1515
from overrides import override
1616
from pydantic import BaseModel, error_wrappers, root_validator
1717

1818
from benchmark.literals import (
1919
LIFECYCLE_KEY,
20-
STOP_KEY,
2120
DPBenchmarkLifecycleState,
2221
DPBenchmarkMissingOptionsError,
2322
Scope,
@@ -106,7 +105,6 @@ class DPBenchmarkWrapperOptionsModel(BaseModel):
106105
workload_name: str
107106
db_info: DPBenchmarkBaseDatabaseModel
108107
report_interval: int
109-
workload_profile: str
110108
labels: str
111109
peers: str | None = None
112110

@@ -125,20 +123,18 @@ def __init__(
125123
self.scope = scope
126124

127125
@property
128-
def relation_data(self) -> dict[str, str]:
126+
def relation_data(self) -> MutableMapping[str, str]:
129127
"""Returns the relation data."""
130128
if self.relation:
131129
return self.relation.data[self.component]
132130
return {}
133131

134132
@property
135-
def remote_data(self) -> dict[str, str]:
133+
def remote_data(self) -> MutableMapping[str, str]:
136134
"""Returns the remote relation data."""
137-
if not self.relation:
135+
if not self.relation or self.scope != Scope.APP:
138136
return {}
139-
if self.scope == Scope.APP:
140-
return self.relation.data[self.relation.app]
141-
return self.relation.data[self.relation.unit]
137+
return self.relation.data[self.relation.app]
142138

143139
def __bool__(self) -> bool:
144140
"""Boolean evaluation based on the existence of self.relation."""
@@ -191,16 +187,6 @@ def lifecycle(self, status: DPBenchmarkLifecycleState | str) -> None:
191187
else:
192188
self.set({LIFECYCLE_KEY: status})
193189

194-
@property
195-
def stop(self) -> bool:
196-
"""Returns the value of the stop key."""
197-
return self.relation_data.get(STOP_KEY, False)
198-
199-
@stop.setter
200-
def stop(self, switch: bool) -> bool:
201-
"""Toggles the stop key value."""
202-
self.set({STOP_KEY: switch})
203-
204190

205191
class DatabaseState(RelationState):
206192
"""State collection for the database relation."""
@@ -236,7 +222,7 @@ def tls_ca(self) -> str | None:
236222
return None
237223
return tls_ca
238224

239-
def get(self) -> DPBenchmarkBaseDatabaseModel | None:
225+
def model(self) -> DPBenchmarkBaseDatabaseModel | None:
240226
"""Returns the value of the key."""
241227
if not self.relation or not (endpoints := self.remote_data.get("endpoints")):
242228
return None
@@ -248,9 +234,9 @@ def get(self) -> DPBenchmarkBaseDatabaseModel | None:
248234
return DPBenchmarkBaseDatabaseModel(
249235
hosts=endpoints.split(),
250236
unix_socket=unix_socket,
251-
username=self.data.get("username"),
252-
password=self.data.get("password"),
253-
db_name=self.remote_data.get(self.database_key),
237+
username=self.data.get("username", ""),
238+
password=self.data.get("password", ""),
239+
db_name=self.remote_data.get(self.database_key, ""),
254240
tls=self.tls,
255241
tls_ca=self.tls_ca,
256242
)

src/benchmark/core/pebble_workload_base.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,9 @@
2020
class DPBenchmarkPebbleTemplatePaths(WorkloadTemplatePaths):
2121
"""Represents the benchmark service template paths."""
2222

23-
def __init__(self):
24-
super().__init__()
25-
self.svc_name = "dpe_benchmark"
26-
2723
@property
2824
@override
29-
def service(self) -> str | None:
25+
def service(self) -> str:
3026
"""The optional path to the service file managing the script."""
3127
return f"/etc/systemd/system/{self.svc_name}.service"
3228

@@ -44,6 +40,12 @@ def templates(self) -> str:
4440
"""The path to the workload template folder."""
4541
return os.path.join(os.environ.get("CHARM_DIR", ""), "templates")
4642

43+
@property
44+
@override
45+
def results(self) -> str:
46+
"""The path to the results folder."""
47+
return "/root/.benchmark/charmed_parameters/results/"
48+
4749
@property
4850
@override
4951
def service_template(self) -> str:

0 commit comments

Comments
 (0)