diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 25727e1..803c0bd 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -14,6 +14,18 @@ on:
   workflow_call:
 
 jobs:
+  lint:
+    name: Lint
+    runs-on: ubuntu-latest
+    timeout-minutes: 5
+    steps:
+      - name: Checkout
+        uses: actions/checkout@v4
+      - name: Install tox
+        run: pipx install tox
+      - name: Run linters
+        run: tox run -e lint
+
   unit-test:
     name: Unit tests
     runs-on: ubuntu-latest
diff --git a/src/benchmark/base_charm.py b/src/benchmark/base_charm.py
index fdce6e3..5a131f8 100644
--- a/src/benchmark/base_charm.py
+++ b/src/benchmark/base_charm.py
@@ -16,10 +16,9 @@
 import logging
 import subprocess
 
-from abc import ABC, abstractmethod
 from typing import Any
 
-import ops
+from charms.data_platform_libs.v0.data_models import TypedCharmBase
 from charms.grafana_agent.v0.cos_agent import COSAgentProvider
 from ops.charm import CharmEvents
 from ops.framework import EventBase, EventSource
@@ -27,15 +26,16 @@
 from benchmark.core.models import DPBenchmarkLifecycleState
 from benchmark.core.pebble_workload_base import DPBenchmarkPebbleWorkloadBase
+from benchmark.core.structured_config import BenchmarkCharmConfig
 from benchmark.core.systemd_workload_base import DPBenchmarkSystemdWorkloadBase
 from benchmark.core.workload_base import WorkloadBase
+from benchmark.events.actions import ActionsHandler
 from benchmark.events.db import DatabaseRelationHandler
 from benchmark.events.peer import PeerRelationHandler
 from benchmark.literals import (
     COS_AGENT_RELATION,
     METRICS_PORT,
     PEER_RELATION,
-    DPBenchmarkLifecycleTransition,
     DPBenchmarkMissingOptionsError,
 )
 from benchmark.managers.config import ConfigManager
@@ -70,34 +70,22 @@ def workload_build(workload_params_template: str) -> WorkloadBase:
     return DPBenchmarkSystemdWorkloadBase(workload_params_template)
 
 
-class DPBenchmarkCharmBase(ops.CharmBase, ABC):
+class DPBenchmarkCharmBase(TypedCharmBase[BenchmarkCharmConfig]):
     """The base benchmark class."""
 
-    on = DPBenchmarkEvents()  # pyright: ignore [reportGeneralTypeIssues]
+    on = DPBenchmarkEvents()  # pyright: ignore [reportAssignmentType]
 
     RESOURCE_DEB_NAME = "benchmark-deb"
     workload_params_template = ""
 
+    config_type = BenchmarkCharmConfig
+
     def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None = None):
         super().__init__(*args)
         self.framework.observe(self.on.install, self._on_install)
         self.framework.observe(self.on.config_changed, self._on_config_changed)
         self.framework.observe(self.on.update_status, self._on_update_status)
-        self.framework.observe(self.on.prepare_action, self.on_prepare_action)
-        self.framework.observe(self.on.run_action, self.on_run_action)
-        self.framework.observe(self.on.stop_action, self.on_stop_action)
-        self.framework.observe(self.on.cleanup_action, self.on_clean_action)
-
-        self.framework.observe(
-            self.on.check_upload,
-            self._on_check_upload,
-        )
-        self.framework.observe(
-            self.on.check_collect,
-            self._on_check_collect,
-        )
-
         self.database = DatabaseRelationHandler(self, db_relation_name)
         self.peers = PeerRelationHandler(self, PEER_RELATION)
         self.framework.observe(self.database.on.db_config_update, self._on_config_changed)
@@ -119,8 +107,8 @@ def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None =
 
         self.config_manager = ConfigManager(
             workload=self.workload,
-            database=self.database.state,
-            peer=self.peers.peers(),
+            database_state=self.database.state,
+            peers=self.peers.peers(),
             config=self.config,
             labels=self.labels,
         )
@@ -129,11 +117,7 @@ def __init__(self, *args, db_relation_name: str, workload: WorkloadBase | None =
             self.peers.this_unit(),
             self.config_manager,
         )
-
-    @abstractmethod
-    def supported_workloads(self) -> list[str]:
-        """List of supported workloads."""
-        ...
+        self.actions = ActionsHandler(self)
 
     ###########################################################################
     #
@@ -146,28 +130,6 @@ def _on_install(self, event: EventBase) -> None:
         self.workload.install()
         self.peers.state.lifecycle = DPBenchmarkLifecycleState.UNSET
 
-    def _on_check_collect(self, event: EventBase) -> None:
-        """Check if the upload is finished."""
-        if self.config_manager.is_collecting():
-            # Nothing to do, upload is still in progress
-            event.defer()
-            return
-
-        if self.unit.is_leader():
-            self.peers.state.set(DPBenchmarkLifecycleState.UPLOADING)
-            # Raise we are running an upload and we will check the status later
-            self.on.check_upload.emit()
-            return
-        self.peers.state.set(DPBenchmarkLifecycleState.FINISHED)
-
-    def _on_check_upload(self, event: EventBase) -> None:
-        """Check if the upload is finished."""
-        if self.config_manager.is_uploading():
-            # Nothing to do, upload is still in progress
-            event.defer()
-            return
-        self.peers.state.lifecycle = DPBenchmarkLifecycleState.FINISHED
-
     def _on_update_status(self, event: EventBase | None = None) -> None:
         """Set status for the operator and finishes the service.
 
@@ -176,7 +138,7 @@ def _on_update_status(self, event: EventBase | None = None) -> None:
         benchmark service and the benchmark status.
         """
         try:
-            status = self.database.state.get()
+            status = self.database.state.model()
         except DPBenchmarkMissingOptionsError as e:
             self.unit.status = BlockedStatus(str(e))
             return
@@ -184,26 +146,12 @@
             self.unit.status = BlockedStatus("No database relation available")
             return
 
-        # We need to narrow the options of workload_name to the supported ones
-        if self.config.get("workload_name") not in self.supported_workloads():
-            self.unit.status = BlockedStatus(
-                f"Unsupported workload: {self.config.get('workload_name')}"
-            )
-            return
-
         # Now, let's check if we need to update our lifecycle position
-        self._update_state()
+        self.update_state()
         self.unit.status = self.lifecycle.status
 
     def _on_config_changed(self, event: EventBase) -> None:
         """Config changed event."""
-        # We need to narrow the options of workload_name to the supported ones
-        if self.config.get("workload_name") not in self.supported_workloads():
-            self.unit.status = BlockedStatus(
-                f"Unsupported workload: {self.config.get('workload_name')}"
-            )
-            return
-
         if not self.config_manager.is_prepared():
             # nothing to do: set the status and leave
             self._on_update_status()
@@ -228,88 +176,6 @@ def scrape_config(self) -> list[dict[str, Any]]:
             }
         ]
 
-    ###########################################################################
-    #
-    # Action and Lifecycle Handlers
-    #
-    ###########################################################################
-
-    def _preflight_checks(self) -> bool:
-        """Check if we have the necessary relations."""
-        if len(self.peers.units()) > 0 and not bool(self.peers.state.get()):
-            return False
-        try:
-            return bool(self.database.state.get())
-        except DPBenchmarkMissingOptionsError:
-            return False
-
-    def on_prepare_action(self, event: EventBase) -> None:
-        """Process the prepare action."""
-        if not self._preflight_checks():
-            event.fail("Missing DB or S3 relations")
-            return
-
-        if not (state := self.lifecycle.next(DPBenchmarkLifecycleTransition.PREPARE)):
-            event.fail("Failed to prepare the benchmark: already done")
-            return
-
-        if state != DPBenchmarkLifecycleState.PREPARING:
-            event.fail(
-                "Another peer is already in prepare state. Wait or call clean action to reset."
-            )
-            return
-
-        # We process the special case of PREPARE, as explained in lifecycle.make_transition()
-        if not self.config_manager.prepare():
-            event.fail("Failed to prepare the benchmark")
-            return
-
-        self.lifecycle.make_transition(state)
-        self.unit.status = self.lifecycle.status
-        event.set_results({"message": "Benchmark is being prepared"})
-
-    def on_run_action(self, event: EventBase) -> None:
-        """Process the run action."""
-        if not self._preflight_checks():
-            event.fail("Missing DB or S3 relations")
-            return
-
-        if not self._process_action_transition(DPBenchmarkLifecycleTransition.RUN):
-            event.fail("Failed to run the benchmark")
-        event.set_results({"message": "Benchmark has started"})
-
-    def on_stop_action(self, event: EventBase) -> None:
-        """Process the stop action."""
-        if not self._preflight_checks():
-            event.fail("Missing DB or S3 relations")
-            return
-
-        if not self._process_action_transition(DPBenchmarkLifecycleTransition.STOP):
-            event.fail("Failed to stop the benchmark")
-        event.set_results({"message": "Benchmark has stopped"})
-
-    def on_clean_action(self, event: EventBase) -> None:
-        """Process the clean action."""
-        if not self._preflight_checks():
-            event.fail("Missing DB or S3 relations")
-            return
-
-        if not self._process_action_transition(DPBenchmarkLifecycleTransition.CLEAN):
-            event.fail("Failed to clean the benchmark")
-        event.set_results({"message": "Benchmark is cleaning"})
-
-    def _process_action_transition(self, transition: DPBenchmarkLifecycleTransition) -> bool:
-        """Process the action."""
-        # First, check if we have an update in our lifecycle state
-        self._update_state()
-
-        if not (state := self.lifecycle.next(transition)):
-            return False
-
-        self.lifecycle.make_transition(state)
-        self.unit.status = self.lifecycle.status
-        return True
-
     ###########################################################################
     #
     # Helpers
     #
     ###########################################################################
@@ -318,9 +184,14 @@ def _process_action_transition(self, transition: DPBenchmarkLifecycleTransition)
 
     def _unit_ip(self) -> str:
         """Current unit ip."""
-        return self.model.get_binding(PEER_RELATION).network.bind_address
+        bind_address = None
+        if PEER_RELATION:
+            if binding := self.model.get_binding(PEER_RELATION):
+                bind_address = binding.network.bind_address
+
+        return str(bind_address) if bind_address else ""
 
-    def _update_state(self) -> None:
+    def update_state(self) -> None:
         """Update the state of the charm."""
         if (next_state := self.lifecycle.next(None)) and self.lifecycle.current() != next_state:
             self.lifecycle.make_transition(next_state)
diff --git a/src/benchmark/core/models.py b/src/benchmark/core/models.py
index 479b983..4978b6a 100644
--- a/src/benchmark/core/models.py
+++ b/src/benchmark/core/models.py
@@ -11,13 +11,12 @@
 import logging
 from typing import Any, Optional
 
-from ops.model import Application, Relation, Unit
+from ops.model import Application, Relation, RelationDataContent, Unit
 from overrides import override
 from pydantic import BaseModel, error_wrappers, root_validator
 
 from benchmark.literals import (
     LIFECYCLE_KEY,
-    STOP_KEY,
     DPBenchmarkLifecycleState,
     DPBenchmarkMissingOptionsError,
     Scope,
@@ -106,7 +105,6 @@ class DPBenchmarkWrapperOptionsModel(BaseModel):
     workload_name: str
     db_info: DPBenchmarkBaseDatabaseModel
     report_interval: int
-    workload_profile: str
     labels: str
     peers: str | None = None
 
@@ -125,20 +123,18 @@ def __init__(
         self.scope = scope
 
     @property
-    def relation_data(self) -> dict[str, str]:
+    def relation_data(self) -> RelationDataContent | dict[Any, Any]:
         """Returns the relation data."""
         if self.relation:
             return self.relation.data[self.component]
         return {}
 
     @property
-    def remote_data(self) -> dict[str, str]:
+    def remote_data(self) -> RelationDataContent | dict[Any, Any]:
         """Returns the remote relation data."""
-        if not self.relation:
+        if not self.relation or self.scope != Scope.APP:
             return {}
-        if self.scope == Scope.APP:
-            return self.relation.data[self.relation.app]
-        return self.relation.data[self.relation.unit]
+        return self.relation.data[self.relation.app]
 
     def __bool__(self) -> bool:
         """Boolean evaluation based on the existence of self.relation."""
@@ -191,16 +187,6 @@ def lifecycle(self, status: DPBenchmarkLifecycleState | str) -> None:
         else:
             self.set({LIFECYCLE_KEY: status})
 
-    @property
-    def stop(self) -> bool:
-        """Returns the value of the stop key."""
-        return self.relation_data.get(STOP_KEY, False)
-
-    @stop.setter
-    def stop(self, switch: bool) -> bool:
-        """Toggles the stop key value."""
-        self.set({STOP_KEY: switch})
-
 
 class DatabaseState(RelationState):
     """State collection for the database relation."""
@@ -236,7 +222,7 @@ def tls_ca(self) -> str | None:
             return None
         return tls_ca
 
-    def get(self) -> DPBenchmarkBaseDatabaseModel | None:
+    def model(self) -> DPBenchmarkBaseDatabaseModel | None:
         """Returns the value of the key."""
         if not self.relation or not (endpoints := self.remote_data.get("endpoints")):
             return None
@@ -248,9 +234,9 @@
         return DPBenchmarkBaseDatabaseModel(
             hosts=endpoints.split(),
             unix_socket=unix_socket,
-            username=self.data.get("username"),
-            password=self.data.get("password"),
-            db_name=self.remote_data.get(self.database_key),
+            username=self.data.get("username", ""),
+            password=self.data.get("password", ""),
+            db_name=self.remote_data.get(self.database_key, ""),
             tls=self.tls,
             tls_ca=self.tls_ca,
         )
diff --git a/src/benchmark/core/pebble_workload_base.py b/src/benchmark/core/pebble_workload_base.py
index 19ae2dd..6dec3d7 100644
--- a/src/benchmark/core/pebble_workload_base.py
+++ b/src/benchmark/core/pebble_workload_base.py
@@ -11,7 +11,6 @@
 
 from charms.operator_libs_linux.v1.systemd import (
     service_restart,
-    service_stop,
 )
 from overrides import override
 
@@ -21,13 +20,9 @@
 class DPBenchmarkPebbleTemplatePaths(WorkloadTemplatePaths):
     """Represents the benchmark service template paths."""
 
-    def __init__(self):
-        super().__init__()
-        self.svc_name = "dpe_benchmark"
-
     @property
     @override
-    def service(self) -> str | None:
+    def service(self) -> str:
         """The optional path to the service file managing the script."""
         return f"/etc/systemd/system/{self.svc_name}.service"
 
@@ -45,6 +40,12 @@ def templates(self) -> str:
         """The path to the workload template folder."""
         return os.path.join(os.environ.get("CHARM_DIR", ""), "templates")
 
+    @property
+    @override
+    def results(self) -> str:
+        """The path to the results folder."""
+        return "/root/.benchmark/charmed_parameters/results/"
+
     @property
     @override
     def service_template(self) -> str:
@@ -88,7 +89,7 @@ def halt(self) -> bool:
     @override
     def reload(self) -> bool:
         """Reloads the workload service."""
-        ...
+        ...
 
     @override
     def read(self, path: str) -> list[str]:
diff --git a/src/benchmark/core/structured_config.py b/src/benchmark/core/structured_config.py
new file mode 100644
index 0000000..1a905b0
--- /dev/null
+++ b/src/benchmark/core/structured_config.py
@@ -0,0 +1,32 @@
+# Copyright 2025 Canonical Ltd.
+# See LICENSE file for licensing details.
+
+"""Structured configuration for the benchmark charm."""
+
+import logging
+
+from charms.data_platform_libs.v0.data_models import BaseConfigModel
+from pydantic import Field, validator
+
+logger = logging.getLogger(__name__)
+
+
+class BenchmarkCharmConfig(BaseConfigModel):
+    """Manager for the structured configuration."""
+
+    test_name: str = Field(default="", validate_default=False)
+    parallel_processes: int = Field(default=1, validate_default=False, ge=1)
+    threads: int = Field(default=1, validate_default=False, ge=1)
+    duration: int = Field(default=0, validate_default=False, ge=0)
+    run_count: int = Field(default=0, validate_default=False, ge=0)
+    workload_name: str = Field(default="default", validate_default=True)
+    override_access_hostname: str = Field(default="", validate_default=False)
+    report_interval: int = Field(default=1, validate_default=False, ge=1)
+
+    @validator("*", pre=True)
+    @classmethod
+    def blank_string(cls, value):
+        """Check for empty strings."""
+        if value == "":
+            return None
+        return value
diff --git a/src/benchmark/core/systemd_workload_base.py b/src/benchmark/core/systemd_workload_base.py
index 329aeb2..bce0b14 100644
--- a/src/benchmark/core/systemd_workload_base.py
+++ b/src/benchmark/core/systemd_workload_base.py
@@ -33,7 +33,7 @@ def __init__(self):
 
     @property
     @override
-    def service(self) -> str | None:
+    def service(self) -> str:
         """The optional path to the service file managing the script."""
         return f"/etc/systemd/system/{self.svc_name}.service"
 
@@ -52,6 +52,7 @@ def workload_params(self) -> str:
         return "/root/.benchmark/charmed_parameters/" + self.svc_name + ".json"
 
     @property
+    @override
     def results(self) -> str:
         """The path to the results folder."""
         return "/root/.benchmark/charmed_parameters/results/"
@@ -98,7 +99,7 @@ def halt(self) -> bool:
     @override
     def reload(self) -> bool:
         """Reloads the script."""
-        daemon_reload()
+        return daemon_reload()
 
     @override
     def read(self, path: str) -> list[str]:
@@ -142,7 +143,7 @@ def exec(
             )
         except subprocess.CalledProcessError:
             return None
-        return output or ""
+        return output.stdout.decode() if output.stdout else None
 
     @override
     def is_active(self) -> bool:
diff --git a/src/benchmark/core/workload_base.py b/src/benchmark/core/workload_base.py
index 30a501a..e1dc679 100644
--- a/src/benchmark/core/workload_base.py
+++ b/src/benchmark/core/workload_base.py
@@ -22,13 +22,13 @@ def bin(self) -> str:
 
     @property
     @abstractmethod
-    def service(self) -> str | None:
+    def service(self) -> str:
         """The optional path to the service file managing the python wrapper."""
         ...
 
     @property
     @abstractmethod
-    def service_template(self) -> str | None:
+    def service_template(self) -> str:
         """The path to the service template file."""
         ...
diff --git a/src/benchmark/events/actions.py b/src/benchmark/events/actions.py
new file mode 100644
index 0000000..55832e3
--- /dev/null
+++ b/src/benchmark/events/actions.py
@@ -0,0 +1,152 @@
+# Copyright 2025 Canonical Ltd.
+# See LICENSE file for licensing details.
+
+"""This module abstracts the benchmark actions and provides a single API set.
+
+The ActionsHandler listens to the charm's action events and drives the benchmark
+lifecycle: it runs the preflight checks, requests the corresponding lifecycle
+transition, and reports the outcome of each action back to the user.
+"""
+
+import logging
+from typing import TYPE_CHECKING
+
+from ops.charm import ActionEvent
+from ops.framework import EventBase
+
+from benchmark.core.models import DPBenchmarkLifecycleState
+from benchmark.literals import (
+    DPBenchmarkLifecycleTransition,
+    DPBenchmarkMissingOptionsError,
+)
+
+if TYPE_CHECKING:
+    # Type-checking-only import: importing base_charm at runtime would be a
+    # circular import, as base_charm imports ActionsHandler from this module.
+    from benchmark.base_charm import DPBenchmarkCharmBase
+
+logger = logging.getLogger(__name__)
+
+
+class ActionsHandler:
+    """Handle the actions for the benchmark charm."""
+
+    def __init__(self, charm: "DPBenchmarkCharmBase"):
+        """Initialize the class."""
+        self.charm = charm
+        self.database = charm.database
+        self.lifecycle = charm.lifecycle
+        self.framework = charm.framework
+        self.config_manager = charm.config_manager
+        self.peers = charm.peers
+        self.unit = charm.unit
+
+        self.framework.observe(self.charm.on.prepare_action, self.on_prepare_action)
+        self.framework.observe(self.charm.on.run_action, self.on_run_action)
+        self.framework.observe(self.charm.on.stop_action, self.on_stop_action)
+        self.framework.observe(self.charm.on.cleanup_action, self.on_clean_action)
+
+        self.framework.observe(
+            self.charm.on.check_upload,
+            self._on_check_upload,
+        )
+        self.framework.observe(
+            self.charm.on.check_collect,
+            self._on_check_collect,
+        )
+
+    def _on_check_collect(self, event: EventBase) -> None:
+        """Check if the collection is finished."""
+        if self.config_manager.is_collecting():
+            # Nothing to do, collection is still in progress
+            event.defer()
+            return
+
+        if self.unit.is_leader():
+            self.peers.state.lifecycle = DPBenchmarkLifecycleState.UPLOADING
+            # Flag that an upload has started; we will check its status later
+            self.charm.on.check_upload.emit()
+            return
+        self.peers.state.lifecycle = DPBenchmarkLifecycleState.FINISHED
+
+    def _on_check_upload(self, event: EventBase) -> None:
+        """Check if the upload is finished."""
+        if self.config_manager.is_uploading():
+            # Nothing to do, upload is still in progress
+            event.defer()
+            return
+        self.peers.state.lifecycle = DPBenchmarkLifecycleState.FINISHED
+
+    def _preflight_checks(self) -> bool:
+        """Check if we have the necessary relations."""
+        try:
+            return bool(self.database.state.model())
+        except DPBenchmarkMissingOptionsError:
+            return False
+
+    def on_prepare_action(self, event: ActionEvent) -> None:
+        """Process the prepare action."""
+        if not self._preflight_checks():
+            event.fail("Missing DB or S3 relations")
+            return
+
+        if not (state := self.lifecycle.next(DPBenchmarkLifecycleTransition.PREPARE)):
+            event.fail("Failed to prepare the benchmark: already done")
+            return
+
+        if state != DPBenchmarkLifecycleState.PREPARING:
+            event.fail(
+                "Another peer is already in prepare state. Wait or call clean action to reset."
+            )
+            return
+
+        # We process the special case of PREPARE, as explained in lifecycle.make_transition()
+        if not self.config_manager.prepare():
+            event.fail("Failed to prepare the benchmark")
+            return
+
+        self.lifecycle.make_transition(state)
+        self.unit.status = self.lifecycle.status
+        event.set_results({"message": "Benchmark is being prepared"})
+
+    def on_run_action(self, event: ActionEvent) -> None:
+        """Process the run action."""
+        if not self._preflight_checks():
+            event.fail("Missing DB or S3 relations")
+            return
+
+        if not self._process_action_transition(DPBenchmarkLifecycleTransition.RUN):
+            event.fail("Failed to run the benchmark")
+        event.set_results({"message": "Benchmark has started"})
+
+    def on_stop_action(self, event: ActionEvent) -> None:
+        """Process the stop action."""
+        if not self._preflight_checks():
+            event.fail("Missing DB or S3 relations")
+            return
+
+        if not self._process_action_transition(DPBenchmarkLifecycleTransition.STOP):
+            event.fail("Failed to stop the benchmark")
+        event.set_results({"message": "Benchmark has stopped"})
+
+    def on_clean_action(self, event: ActionEvent) -> None:
+        """Process the clean action."""
+        if not self._preflight_checks():
+            event.fail("Missing DB or S3 relations")
+            return
+
+        if not self._process_action_transition(DPBenchmarkLifecycleTransition.CLEAN):
+            event.fail("Failed to clean the benchmark")
+        event.set_results({"message": "Benchmark is cleaning"})
+
+    def _process_action_transition(self, transition: DPBenchmarkLifecycleTransition) -> bool:
+        """Process the action."""
+        # First, check if we have an update in our lifecycle state
+        self.charm.update_state()
+
+        if not (state := self.lifecycle.next(transition)):
+            return False
+
+        self.lifecycle.make_transition(state)
+        self.unit.status = self.lifecycle.status
+        return True
diff --git a/src/benchmark/events/db.py b/src/benchmark/events/db.py
index a6d8b7c..8b01106 100644
--- a/src/benchmark/events/db.py
+++ b/src/benchmark/events/db.py
@@ -9,11 +9,13 @@
 """
 
 import logging
+from abc import abstractmethod
 
 from charms.data_platform_libs.v0.data_interfaces import DatabaseRequires
 from ops.charm import CharmBase, CharmEvents
 from ops.framework import EventBase, EventSource
 
+from benchmark.core.models import DatabaseState
 from benchmark.events.handler import RelationHandler
 from benchmark.literals import DPBenchmarkMissingOptionsError
 
@@ -37,7 +39,7 @@ class DatabaseRelationHandler(RelationHandler):
     well as the current relation status.
""" - on = DatabaseHandlerEvents() # pyright: ignore [reportGeneralTypeIssues] + on = DatabaseHandlerEvents() # pyright: ignore [reportAssignmentType] def __init__( self, @@ -58,48 +60,10 @@ def __init__( self.charm.on[self.relation_name].relation_broken, self._on_endpoints_changed ) - # @property - # def username(self) -> str|None: - # """Returns the username to connect to the database.""" - # return (self._secret_user or {}).get("username") - - # @property - # def password(self) -> str|None: - # """Returns the password to connect to the database.""" - # return (self._secret_user or {}).get("password") - - # @property - # def tls(self) -> str|None: - # """Returns the TLS to connect to the database.""" - # tls = (self._secret_tls or {}).get("tls") - # if not tls or tls == "disabled": - # return None - # return tls - - # @property - # def tls_ca(self) -> str|None: - # """Returns the TLS CA to connect to the database.""" - # tls_ca = (self._secret_user or {}).get("tls_ca") - # if not tls_ca or tls_ca == "disabled": - # return None - # return tls_ca - - # @property - # def _secret_user(self) -> dict[str, str]|None: - # if not (secret_id := self.client.fetch_relation_data()[self.relation.id].get("secret-user")): - # return None - # return self.charm.framework.model.get_secret(id=secret_id).get_content() - - # @property - # def _secret_tls(self) -> dict[str, str]|None: - # if not (secret_id := self.client.fetch_relation_data()[self.relation.id].get("secret-tls")): - # return None - # return self.charm.framework.model.get_secret(id=secret_id).get_content() - - def _on_endpoints_changed(self, event: EventBase) -> None: + def _on_endpoints_changed(self, _: EventBase) -> None: """Handles the endpoints_changed event.""" try: - if self.state.get(): + if self.state.model(): self.on.db_config_update.emit() except DPBenchmarkMissingOptionsError as e: logger.warning(f"Missing options: {e}") @@ -109,3 +73,9 @@ def _on_endpoints_changed(self, event: EventBase) -> None: def client(self) -> DatabaseRequires: """Returns the data_interfaces client corresponding to the database.""" ... + + @property + @abstractmethod + def state(self) -> DatabaseState: + """Returns the state of the database.""" + ... diff --git a/src/benchmark/literals.py b/src/benchmark/literals.py index fa83cd1..7952a34 100644 --- a/src/benchmark/literals.py +++ b/src/benchmark/literals.py @@ -10,7 +10,6 @@ # Peer relation keys LIFECYCLE_KEY = "lifecycle" -STOP_KEY = "stop" class Substrate(str, Enum): diff --git a/src/benchmark/managers/collector.py b/src/benchmark/managers/collector.py deleted file mode 100644 index 94ebd92..0000000 --- a/src/benchmark/managers/collector.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2024 Canonical Ltd. -# See LICENSE file for licensing details. - -"""The collector class. - -This class runs all the collection tasks for a given result. 
-""" - -from benchmark.core.models import SosreportCLIArgsModel -from benchmark.core.workload_base import WorkloadBase - - -class CollectorManager: - """The collector manager class.""" - - def __init__( - self, - workload: WorkloadBase, - sosreport_config: SosreportCLIArgsModel | None = None, - ): - # TODO: we need a way to run "sos collect" - # For that, we will have to manage ssh keys between the peers - # E.G.: - # sudo sos collect \ - # -i ~/.local/share/juju/ssh/juju_id_rsa --ssh-user ubuntu --no-local \ - # --nodes "$NODES" \ - # --only-plugins systemd,logs,juju \ - # -k logs.all_logs=true \ - # --batch \ - # --clean \ - # --tmp-dir=/tmp/sos \ - # -z gzip -j 1 - self.workload = workload - if not sosreport_config: - if workload.is_running_on_k8s(): - self.sosreport_config = SosreportCLIArgsModel( - plugins=["systemd", "logs", "juju"], - ) - else: - self.sosreport_config = SosreportCLIArgsModel( - plugins=["logs", "juju"], - ) - self.sosreport_config = sosreport_config - - def install(self) -> bool: - """Installs the collector.""" - ... - - def collect_sosreport(self) -> bool: - """Collect the sosreport.""" - self.workload.exec( - command=["sosreport"] + str(self.sosreport_config).split(), - ) diff --git a/src/benchmark/managers/config.py b/src/benchmark/managers/config.py index 399b332..7387915 100644 --- a/src/benchmark/managers/config.py +++ b/src/benchmark/managers/config.py @@ -19,6 +19,7 @@ DatabaseState, DPBenchmarkWrapperOptionsModel, ) +from benchmark.core.structured_config import BenchmarkCharmConfig from benchmark.core.workload_base import WorkloadBase from benchmark.literals import DPBenchmarkLifecycleTransition @@ -34,7 +35,7 @@ def __init__( workload: WorkloadBase, database_state: DatabaseState, peers: list[str], - config: dict[str, Any], + config: BenchmarkCharmConfig, labels: str, ): self.workload = workload @@ -61,7 +62,7 @@ def is_cleaned(self) -> bool: @property def _test_name(self) -> str: """Return the test name.""" - return self.config.get("test_name") or "dpe-benchmark" + return self.config.test_name or "dpe-benchmark" @property def test_name(self) -> str: @@ -76,20 +77,19 @@ def get_execution_options( Raises: DPBenchmarkMissingOptionsError: If the database is not ready. """ - if not (db := self.database_state.get()): + if not (db := self.database_state.model()): # It means we are not yet ready. 
Return None # This check also serves to ensure we have only one valid relation at the time return None return DPBenchmarkWrapperOptionsModel( test_name=self.test_name, - parallel_processes=self.config.get("parallel_processes"), - threads=self.config.get("threads"), - duration=self.config.get("duration"), - run_count=self.config.get("run_count"), + parallel_processes=self.config.parallel_processes, + threads=self.config.threads, + duration=self.config.duration, + run_count=self.config.run_count, db_info=db, - workload_name=self.config.get("workload_name"), - report_interval=self.config.get("report_interval"), - workload_profile=self.config.get("workload_profile"), + workload_name=self.config.workload_name, + report_interval=self.config.report_interval, labels=self.labels, peers=",".join(self.peers), ) @@ -174,7 +174,7 @@ def is_failed( def _render_params( self, - dst_path: str | None = None, + dst_path: str, ) -> str | None: """Render the workload parameters.""" return self._render( @@ -190,7 +190,9 @@ def _render_service( dst_path: str | None = None, ) -> str | None: """Render the workload parameters.""" - values = self.get_execution_options().dict() | { + if not (options := self.get_execution_options()): + return None + values = options.dict() | { "charm_root": os.environ.get("CHARM_DIR", ""), "command": transition.value, } @@ -216,7 +218,9 @@ def _check( "command": transition.value, "target_hosts": values.db_info.hosts, } - compare_svc = "\n".join(self.workload.read(self.workload.paths.service)) == self._render( + compare_svc = "\n".join( + self.workload.read(self.workload.paths.service) or "" + ) == self._render( values=values, template_file=self.workload.paths.service_template, template_content=None, @@ -239,7 +243,7 @@ def _render( template_file: str | None, template_content: str | None, dst_filepath: str | None = None, - ) -> str: + ) -> str | None: """Renders from a file or an string content and return final rendered value.""" try: if template_file: @@ -247,7 +251,7 @@ def _render( template = template_env.get_template(template_file) else: template_env = Environment( - loader=DictLoader({"workload_params": template_content}) + loader=DictLoader({"workload_params": template_content or ""}) ) template = template_env.get_template("workload_params") content = template.render(values) diff --git a/src/benchmark/managers/lifecycle.py b/src/benchmark/managers/lifecycle.py index 7869807..9dfd139 100644 --- a/src/benchmark/managers/lifecycle.py +++ b/src/benchmark/managers/lifecycle.py @@ -3,6 +3,9 @@ """The lifecycle manager class.""" +from abc import ABC, abstractmethod +from typing import Optional + from ops.model import ( ActiveStatus, BlockedStatus, @@ -26,7 +29,7 @@ class LifecycleManager: def __init__( self, peers: dict[Unit, PeerState], - this_unit: PeerState, + this_unit: Unit, config_manager: ConfigManager, ): self.peers = peers @@ -101,127 +104,16 @@ def make_transition(self, new_state: DPBenchmarkLifecycleState) -> bool: # noqa self.peers[self.this_unit].lifecycle = new_state.value return True - def next( # noqa: C901 + def next( self, transition: DPBenchmarkLifecycleTransition | None = None ) -> DPBenchmarkLifecycleState | None: """Return the next lifecycle state.""" - # Changes that takes us to UNSET: - if transition == DPBenchmarkLifecycleTransition.CLEAN: - # Simplest case, we return to unset - return DPBenchmarkLifecycleState.UNSET - - # Changes that takes us to STOPPED: - # Either we received a stop transition - if transition == DPBenchmarkLifecycleTransition.STOP: - return 
DPBenchmarkLifecycleState.STOPPED - # OR one of our peers is in stopped state - if ( - self._compare_lifecycle_states( - self._peers_state(), - DPBenchmarkLifecycleState.STOPPED, - ) - == 0 - ): - return DPBenchmarkLifecycleState.STOPPED - - # FAILED takes precedence over all other states - # Changes that takes us to FAILED: - # Workload has failed and we were: - # - PREPARING - # - RUNNING - # - COLLECTING - # - UPLOADING - if ( - self.current() - in [ - DPBenchmarkLifecycleState.PREPARING, - DPBenchmarkLifecycleState.RUNNING, - DPBenchmarkLifecycleState.COLLECTING, - DPBenchmarkLifecycleState.UPLOADING, - ] - and self.config_manager.workload.is_failed() - ): - return DPBenchmarkLifecycleState.FAILED - - # Changes that takes us to PREPARING: - # We received a prepare signal and no one else is available yet or we failed previously - if transition == DPBenchmarkLifecycleTransition.PREPARE and self._peers_state() in [ - DPBenchmarkLifecycleState.UNSET, - DPBenchmarkLifecycleState.FAILED, - ]: - return DPBenchmarkLifecycleState.PREPARING - elif transition == DPBenchmarkLifecycleTransition.PREPARE: - # Failed to calculate a proper state as we have neighbors in more advanced state for now - return None - - # Changes that takes us to AVAILABLE: - # Either we were in preparing and we are finished - if ( - self.current() == DPBenchmarkLifecycleState.PREPARING - and self.config_manager.is_prepared() - ): - return DPBenchmarkLifecycleState.AVAILABLE - # OR highest peers state is AVAILABLE but no actions has happened - if ( - transition is None - and self._compare_lifecycle_states( - self._peers_state(), - DPBenchmarkLifecycleState.AVAILABLE, - ) - == 0 - ): - return DPBenchmarkLifecycleState.AVAILABLE - - # Changes that takes us to RUNNING: - # Either we receive a transition to running and we were in one of: - # - AVAILABLE - # - FAILED - # - STOPPED - # - FINISHED - if transition == DPBenchmarkLifecycleTransition.RUN and self.current() in [ - DPBenchmarkLifecycleState.AVAILABLE, - DPBenchmarkLifecycleState.FAILED, - DPBenchmarkLifecycleState.STOPPED, - DPBenchmarkLifecycleState.FINISHED, - ]: - return DPBenchmarkLifecycleState.RUNNING - # OR any other peer is beyond the >=RUNNING state - # and we are still AVAILABLE. 
- if self._compare_lifecycle_states( - self._peers_state(), - DPBenchmarkLifecycleState.RUNNING, - ) == 0 and self.current() in [ - DPBenchmarkLifecycleState.UNSET, - DPBenchmarkLifecycleState.AVAILABLE, - ]: - return DPBenchmarkLifecycleState.RUNNING - - # Changes that takes us to COLLECTING: - # the workload is in collecting state - if self.config_manager.is_collecting(): - return DPBenchmarkLifecycleState.COLLECTING - - # Changes that takes us to UPLOADING: - # the workload is in uploading state - if self.config_manager.is_uploading(): - return DPBenchmarkLifecycleState.UPLOADING - - # Changes that takes us to FINISHED: - # Workload has finished and we were in one of: - # - RUNNING - # - UPLOADING - if ( - self.current() - in [ - DPBenchmarkLifecycleState.RUNNING, - DPBenchmarkLifecycleState.UPLOADING, - ] - and self.config_manager.workload.is_halted() - ): - return DPBenchmarkLifecycleState.FINISHED - - # We are in an incongruent state OR the transition does not make sense - return None + lifecycle_state = _LifecycleStateFactory().build( + self, + self.current(), + ) + result = lifecycle_state.next(transition) + return result.state if result else None def _peers_state(self) -> DPBenchmarkLifecycleState | None: next_state = self.peers[self.this_unit].lifecycle @@ -229,10 +121,20 @@ def _peers_state(self) -> DPBenchmarkLifecycleState | None: neighbor = self.peers[unit].lifecycle if neighbor is None: continue - elif self._compare_lifecycle_states(neighbor, next_state) > 0: + elif not next_state or self._compare_lifecycle_states(neighbor, next_state) > 0: next_state = neighbor return next_state or DPBenchmarkLifecycleState.UNSET + def check_all_peers_in_state(self, state: DPBenchmarkLifecycleState) -> bool: + """Check if the unit can run the workload. + + That happens if all the peers are set as state value. + """ + for unit in self.peers.keys(): + if state != self.peers[unit].lifecycle: + return False + return True + @property def status(self) -> StatusBase: """Return the status of the benchmark.""" @@ -289,3 +191,230 @@ def _get_value(phase: DPBenchmarkLifecycleState) -> int: # noqa: C901 return 8 return _get_value(neighbor) - _get_value(this) + + +class _LifecycleState(ABC): + """The lifecycle state represents a single state and encapsulates the transition logic.""" + + state: DPBenchmarkLifecycleState + + def __init__(self, manager: LifecycleManager): + self.manager = manager + + @abstractmethod + def next( + self, transition: Optional[DPBenchmarkLifecycleTransition] = None + ) -> Optional["_LifecycleState"]: + """Returns the next state given a transition request.""" + ... 
+
+
+class _StoppedLifecycleState(_LifecycleState):
+    """The stopped lifecycle state."""
+
+    state = DPBenchmarkLifecycleState.STOPPED
+
+    def next(
+        self, transition: DPBenchmarkLifecycleTransition | None = None
+    ) -> Optional["_LifecycleState"]:
+        if transition == DPBenchmarkLifecycleTransition.CLEAN:
+            return _UnsetLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_running():
+            return _RunningLifecycleState(self.manager)
+
+        if transition == DPBenchmarkLifecycleTransition.RUN:
+            return _RunningLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_failed():
+            return _FailedLifecycleState(self.manager)
+
+        return None
+
+
+class _FailedLifecycleState(_LifecycleState):
+    """The failed lifecycle state."""
+
+    state = DPBenchmarkLifecycleState.FAILED
+
+    def next(
+        self, transition: DPBenchmarkLifecycleTransition | None = None
+    ) -> Optional["_LifecycleState"]:
+        if transition == DPBenchmarkLifecycleTransition.CLEAN:
+            return _UnsetLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_running():
+            return _RunningLifecycleState(self.manager)
+
+        if transition == DPBenchmarkLifecycleTransition.RUN:
+            return _RunningLifecycleState(self.manager)
+
+        return None
+
+
+class _FinishedLifecycleState(_LifecycleState):
+    """The finished lifecycle state."""
+
+    state = DPBenchmarkLifecycleState.FINISHED
+
+    def next(
+        self, transition: DPBenchmarkLifecycleTransition | None = None
+    ) -> Optional["_LifecycleState"]:
+        if transition == DPBenchmarkLifecycleTransition.CLEAN:
+            return _UnsetLifecycleState(self.manager)
+
+        if transition == DPBenchmarkLifecycleTransition.STOP:
+            return _StoppedLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_running():
+            return _RunningLifecycleState(self.manager)
+
+        if transition == DPBenchmarkLifecycleTransition.RUN:
+            return _RunningLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_failed():
+            return _FailedLifecycleState(self.manager)
+
+        return None
+
+
+class _RunningLifecycleState(_LifecycleState):
+    """The running lifecycle state."""
+
+    state = DPBenchmarkLifecycleState.RUNNING
+
+    def next(
+        self, transition: DPBenchmarkLifecycleTransition | None = None
+    ) -> Optional["_LifecycleState"]:
+        if transition == DPBenchmarkLifecycleTransition.CLEAN:
+            return _UnsetLifecycleState(self.manager)
+
+        if transition == DPBenchmarkLifecycleTransition.STOP:
+            return _StoppedLifecycleState(self.manager)
+
+        if (peer_state := self.manager._peers_state()) and (
+            self.manager._compare_lifecycle_states(
+                peer_state,
+                DPBenchmarkLifecycleState.STOPPED,
+            )
+            == 0
+        ):
+            return _StoppedLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_failed():
+            return _FailedLifecycleState(self.manager)
+
+        if not self.manager.config_manager.is_running():
+            # TODO: Collect state should be implemented here instead
+            return _FinishedLifecycleState(self.manager)
+
+        return None
+
+
+class _AvailableLifecycleState(_LifecycleState):
+    """The available lifecycle state."""
+
+    state = DPBenchmarkLifecycleState.AVAILABLE
+
+    def next(
+        self, transition: DPBenchmarkLifecycleTransition | None = None
+    ) -> Optional["_LifecycleState"]:
+        if transition == DPBenchmarkLifecycleTransition.CLEAN:
+            return _UnsetLifecycleState(self.manager)
+
+        if transition == DPBenchmarkLifecycleTransition.RUN:
+            return _RunningLifecycleState(self.manager)
+
+        if (peer_state := self.manager._peers_state()) and (
+            self.manager._compare_lifecycle_states(
+                peer_state,
+                DPBenchmarkLifecycleState.RUNNING,
+            )
+            == 0
+        ):
+            return _RunningLifecycleState(self.manager)
+
+        return None
+
+
+class _PreparingLifecycleState(_LifecycleState):
+    """The preparing lifecycle state."""
+
+    state = DPBenchmarkLifecycleState.PREPARING
+
+    def next(
+        self, transition: DPBenchmarkLifecycleTransition | None = None
+    ) -> Optional["_LifecycleState"]:
+        if transition == DPBenchmarkLifecycleTransition.CLEAN:
+            return _UnsetLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_failed():
+            return _FailedLifecycleState(self.manager)
+
+        if self.manager.config_manager.is_prepared():
+            return _AvailableLifecycleState(self.manager)
+
+        return None
+
+
+class _UnsetLifecycleState(_LifecycleState):
+    """The unset lifecycle state."""
+
+    state = DPBenchmarkLifecycleState.UNSET
+
+    def next(
+        self, transition: DPBenchmarkLifecycleTransition | None = None
+    ) -> Optional["_LifecycleState"]:
+        if transition == DPBenchmarkLifecycleTransition.PREPARE:
+            return _PreparingLifecycleState(self.manager)
+
+        if (peer_state := self.manager._peers_state()) and (
+            self.manager._compare_lifecycle_states(
+                peer_state,
+                DPBenchmarkLifecycleState.AVAILABLE,
+            )
+            == 0
+        ):
+            return _AvailableLifecycleState(self.manager)
+
+        if (peer_state := self.manager._peers_state()) and (
+            self.manager._compare_lifecycle_states(
+                peer_state,
+                DPBenchmarkLifecycleState.RUNNING,
+            )
+            == 0
+        ):
+            return _RunningLifecycleState(self.manager)
+
+        return None
+
+
+class _LifecycleStateFactory:
+    """The lifecycle state factory."""
+
+    def build(
+        self, manager: LifecycleManager, state: DPBenchmarkLifecycleState
+    ) -> _LifecycleState:
+        """Build the lifecycle state."""
+        if state == DPBenchmarkLifecycleState.UNSET:
+            return _UnsetLifecycleState(manager)
+
+        if state == DPBenchmarkLifecycleState.PREPARING:
+            return _PreparingLifecycleState(manager)
+
+        if state == DPBenchmarkLifecycleState.AVAILABLE:
+            return _AvailableLifecycleState(manager)
+
+        if state == DPBenchmarkLifecycleState.RUNNING:
+            return _RunningLifecycleState(manager)
+
+        if state == DPBenchmarkLifecycleState.FAILED:
+            return _FailedLifecycleState(manager)
+
+        if state == DPBenchmarkLifecycleState.FINISHED:
+            return _FinishedLifecycleState(manager)
+
+        if state == DPBenchmarkLifecycleState.STOPPED:
+            return _StoppedLifecycleState(manager)
+
+        raise ValueError("Unknown state")
diff --git a/src/benchmark/wrapper/core.py b/src/benchmark/wrapper/core.py
index d49918c..d967464 100644
--- a/src/benchmark/wrapper/core.py
+++ b/src/benchmark/wrapper/core.py
@@ -116,18 +116,18 @@ class KafkaBenchmarkSample(BaseModel):
 class KafkaBenchmarkSampleMatcher(Enum):
     """Hard-coded regexes to process the benchmark sample."""
 
-    produce_rate: str = r"Pub rate\s+(.*?)\s+msg/s"
-    produce_throughput: str = r"Pub rate\s+\d+.\d+\s+msg/s\s+/\s+(.*?)\s+MB/s"
-    produce_error_rate: str = r"Pub err\s+(.*?)\s+err/s"
-    produce_latency_avg: str = r"Pub Latency \(ms\) avg:\s+(.*?)\s+"
+    produce_rate = r"Pub rate\s+(.*?)\s+msg/s"
+    produce_throughput = r"Pub rate\s+\d+.\d+\s+msg/s\s+/\s+(.*?)\s+MB/s"
+    produce_error_rate = r"Pub err\s+(.*?)\s+err/s"
+    produce_latency_avg = r"Pub Latency \(ms\) avg:\s+(.*?)\s+"
 
     # Match: Pub Latency (ms) avg: 1478.1 - 50%: 1312.6 - 99%: 4981.5 - 99.9%: 5104.7 - Max: 5110.5
     # Generates: [('1478.1', '1312.6', '4981.5', '5104.7', '5110.5')]
-    produce_latency_percentiles: str = r"Pub Latency \(ms\) avg:\s+(.*?)\s+- 50%:\s+(.*?)\s+- 99%:\s+(.*?)\s+- 99.9%:\s+(.*?)\s+- Max:\s+(.*?)\s+"
+    produce_latency_percentiles = r"Pub Latency \(ms\) avg:\s+(.*?)\s+- 50%:\s+(.*?)\s+- 99%:\s+(.*?)\s+- 99.9%:\s+(.*?)\s+- Max:\s+(.*?)\s+"
 
     # Pub Delay Latency (us) avg: 21603452.9 - 50%: 21861759.0 - 99%: 23621631.0 - 99.9%: 24160895.0 - Max: 24163839.0
     # Generates: [('21603452.9', '21861759.0', '23621631.0', '24160895.0', '24163839.0')]
-    produce_latency_delay_percentiles: str = r"Pub Delay Latency \(us\) avg:\s+(.*?)\s+- 50%:\s+(.*?)\s+- 99%:\s+(.*?)\s+- 99.9%:\s+(.*?)\s+- Max:\s+(\d+\.\d+)"
+    produce_latency_delay_percentiles = r"Pub Delay Latency \(us\) avg:\s+(.*?)\s+- 50%:\s+(.*?)\s+- 99%:\s+(.*?)\s+- 99.9%:\s+(.*?)\s+- Max:\s+(\d+\.\d+)"
 
-    consume_rate: str = r"Cons rate\s+(.*?)\s+msg/s"
-    consume_throughput: str = r"Cons rate\s+\d+.\d+\s+msg/s\s+/\s+(.*?)\s+MB/s"
-    consume_backlog: str = r"Backlog:\s+(.*?)\s+K"
+    consume_rate = r"Cons rate\s+(.*?)\s+msg/s"
+    consume_throughput = r"Cons rate\s+\d+.\d+\s+msg/s\s+/\s+(.*?)\s+MB/s"
+    consume_backlog = r"Backlog:\s+(.*?)\s+K"
diff --git a/src/benchmark/wrapper/main.py b/src/benchmark/wrapper/main.py
index 675fbf9..41da76a 100755
--- a/src/benchmark/wrapper/main.py
+++ b/src/benchmark/wrapper/main.py
@@ -24,6 +24,9 @@ def run(self):
         """Prepares the workload and runs the benchmark."""
         manager, _ = self.mapping.map(self.args.command)
 
+        if not manager:
+            raise ValueError("No manager found for the command")
+
         logging.basicConfig(filename=self.args.log_file, encoding="utf-8", level=logging.INFO)
 
         def _exit(*args, **kwargs):
diff --git a/src/benchmark/wrapper/process.py b/src/benchmark/wrapper/process.py
index 5398158..5945db1 100644
--- a/src/benchmark/wrapper/process.py
+++ b/src/benchmark/wrapper/process.py
@@ -70,10 +70,12 @@ def start(self):
             cwd=self.model.cwd,
         )
         # Now, let's make stdout a non-blocking file
-        os.set_blocking(self._proc.stdout.fileno(), False)
+        if self._proc:
+            if self._proc.stdout:
+                os.set_blocking(self._proc.stdout.fileno(), False)
 
-        self.model.pid = self._proc.pid
-        self.model.status = ProcessStatus.RUNNING
+            self.model.pid = self._proc.pid
+            self.model.status = ProcessStatus.RUNNING
 
     def status(self) -> ProcessStatus:
         """Return the status of the process."""
@@ -86,7 +88,9 @@ def status(self) -> ProcessStatus:
             stat = ProcessStatus.RUNNING
         elif self._proc.returncode != 0:
             stat = ProcessStatus.ERROR
-        self.model.status = stat
+
+        if self.model:
+            self.model.status = stat
         return stat
 
     async def process(
@@ -104,7 +108,7 @@ async def process(
             or (self.status() == ProcessStatus.RUNNING and self.args.duration == 0)
         ):
             to_wait = True
-            if self._proc:
+            if self._proc and self._proc.stdout:
                 for line in iter(self._proc.stdout.readline, ""):
                     if output := self.process_line(line):
                         self.metrics.add(output)
@@ -140,7 +144,8 @@ def stop(self):
                 self._proc.kill()
         except Exception as e:
             logger.warning(f"Error stopping worker: {e}")
-        self.model.status = ProcessStatus.STOPPED
+        if self.model:
+            self.model.status = ProcessStatus.STOPPED
 
     @abstractmethod
     def process_line(self, line: str) -> BaseModel | None:
@@ -202,11 +207,15 @@ def __init__(self, args: WorkloadCLIArgsModel, metrics: BenchmarkMetrics):
         self.manager = None
         self.metrics = metrics
 
-    def status(self) -> ProcessStatus:
+    def status(self) -> ProcessStatus | None:
         """Return the status of the benchmark."""
-        return self.manager.status()
+        if self.manager:
+            return self.manager.status()
+        return None
 
-    def map(self, cmd: BenchmarkCommand) -> tuple[BenchmarkManager, list[BenchmarkProcess]]:
+    def map(
+        self, cmd: BenchmarkCommand
+    ) -> tuple[BenchmarkManager | None, list[BenchmarkProcess] | None]:
         """Processes high-level arguments into the benchmark manager and workers.
 
         Returns all the processes that will be running the benchmark.
diff --git a/src/charm.py b/src/charm.py
index 5ca2913..40c5be1 100755
--- a/src/charm.py
+++ b/src/charm.py
@@ -33,9 +33,10 @@
 from benchmark.core.models import (
     DatabaseState,
     DPBenchmarkBaseDatabaseModel,
-    RelationState,
 )
+from benchmark.core.structured_config import BenchmarkCharmConfig
 from benchmark.core.workload_base import WorkloadBase
+from benchmark.events.actions import ActionsHandler
 from benchmark.events.db import DatabaseRelationHandler
 from benchmark.events.peer import PeerRelationHandler
 from benchmark.literals import (
@@ -45,6 +46,7 @@
 from benchmark.managers.config import ConfigManager
 from benchmark.managers.lifecycle import LifecycleManager
 from literals import CLIENT_RELATION_NAME, TOPIC_NAME
+from models import KafkaBenchmarkCharmConfig
 
 # Log messages can be retrieved using juju debug-log
 logger = logging.getLogger(__name__)
@@ -123,12 +125,13 @@ def __init__(
         )
         self.database_key = "topic"
 
-    def get(self) -> DPBenchmarkBaseDatabaseModel | None:
-        """Returns the value of the key."""
+    @override
+    def model(self) -> DPBenchmarkBaseDatabaseModel | None:
+        """Returns the database model."""
         if not self.relation or not (endpoints := self.remote_data.get("endpoints")):
             return None
-
-        dbmodel = super().get()
+        if not (dbmodel := super().model()):
+            return None
         return DPBenchmarkBaseDatabaseModel(
             hosts=endpoints.split(","),
             unix_socket=dbmodel.unix_socket,
@@ -168,7 +171,7 @@ def __init__(
 
     @property
     @override
-    def state(self) -> RelationState:
+    def state(self) -> KafkaDatabaseState:
         """Returns the state of the database."""
         if not (
             self.relation and self.client and self.relation.id in self.client.fetch_relation_data()
@@ -192,11 +195,13 @@ def client(self) -> Any:
         """Returns the data_interfaces client corresponding to the database."""
         return self._internal_client
 
-    def bootstrap_servers(self) -> str | None:
+    def bootstrap_servers(self) -> list[str] | None:
         """Return the bootstrap servers."""
-        return self.state.get().hosts
+        if not self.state or not (model := self.state.model()):
+            return None
+        return model.hosts
 
-    def tls(self) -> tuple[str, str] | None:
+    def tls(self) -> tuple[str | None, str | None]:
         """Return the TLS certificates."""
         if not self.state.tls_ca:
             return self.state.tls, None
@@ -224,7 +229,7 @@ def __init__(
         workload: WorkloadBase,
         database_state: DatabaseState,
         peers: list[str],
-        config: dict[str, Any],
+        config: BenchmarkCharmConfig,
         labels: str,
     ):
         super().__init__(workload, database_state, peers, config, labels)
@@ -237,7 +242,9 @@ def _render_service(
         dst_path: str | None = None,
     ) -> str | None:
         """Render the workload parameters."""
-        values = self.get_execution_options().dict() | {
+        if not (options := self.get_execution_options()):
+            return None
+        values = options.dict() | {
             "charm_root": os.environ.get("CHARM_DIR", ""),
             "command": transition.value,
         }
@@ -282,18 +289,19 @@ def _check(
 
     def get_worker_params(self) -> dict[str, Any]:
         """Return the workload parameters."""
-        db = self.database.state.get()
+        if not (db := self.database_state.model()):
+            return {}
 
         return {
-            "total_number_of_brokers": len(self.peer.units()) + 1,
+            "total_number_of_brokers": len(self.peers) + 1,
             # We cannot have quotes nor brackets in this string.
             # Therefore, we render the entire line instead
             "list_of_brokers_bootstrap": "bootstrap.servers={}".format(
-                ",".join(self.database.bootstrap_servers())
+                ",".join(db.hosts) if db.hosts else ""
             ),
             "username": db.username,
             "password": db.password,
-            "threads": self.config.get("threads", 1) if self.config.get("threads") > 0 else 1,
+            "threads": self.config.threads if self.config.threads > 0 else 1,
         }
 
     def _render_worker_params(
@@ -312,9 +320,9 @@ def _render_worker_params(
     def get_workload_params(self) -> dict[str, Any]:
         """Return the worker parameters."""
         return {
-            "partitionsPerTopic": self.config.get("parallel_processes"),
-            "duration": int(self.config.get("duration") / 60)
-            if self.config.get("duration") > 0
+            "partitionsPerTopic": self.config.parallel_processes,
+            "duration": int(self.config.duration / 60)
+            if self.config.duration > 0
             else TEN_YEARS_IN_MINUTES,
             "charm_root": os.environ.get("CHARM_DIR", ""),
         }
@@ -322,7 +330,7 @@ def get_workload_params(self) -> dict[str, Any]:
     @override
     def _render_params(
         self,
-        dst_path: str | None = None,
+        dst_path: str,
    ) -> str | None:
         """Render the workload parameters.
@@ -344,12 +352,16 @@ def prepare(self) -> bool:
         # First, clean if a topic already existed
         self.clean()
         try:
-            topic = NewTopic(
-                name=self.database.state.get().db_name,
-                num_partitions=self.config.get("threads") * self.config.get("parallel_processes"),
-                replication_factor=self.client.replication_factor,
-            )
-            self.client.create_topic(topic)
+            if model := self.database_state.model():
+                topic = NewTopic(
+                    name=model.db_name,
+                    num_partitions=self.config.threads * self.config.parallel_processes,
+                    replication_factor=self.client.replication_factor,
+                )
+                self.client.create_topic(topic)
+            else:
+                logger.warning("No database model found")
+                return False
         except Exception as e:
             logger.debug(f"Error creating topic: {e}")
 
@@ -360,16 +372,18 @@ def prepare(self) -> bool:
     def is_prepared(self) -> bool:
         """Checks if the benchmark service has passed its "prepare" status."""
         try:
-            return self.database.state.get().db_name in self.client._admin_client.list_topics()
+            if model := self.database_state.model():
+                return model.db_name in self.client._admin_client.list_topics()
         except Exception as e:
             logger.info(f"Error describing topic: {e}")
-            return False
+        return False
 
     @override
     def clean(self) -> bool:
         """Clean the benchmark service."""
         try:
-            self.client.delete_topics([self.database.state.get().db_name])
+            if model := self.database_state.model():
+                self.client.delete_topics([model.db_name])
         except Exception as e:
             logger.info(f"Error deleting topic: {e}")
         return self.is_cleaned()
 
@@ -378,7 +392,9 @@ def clean(self) -> bool:
     def is_cleaned(self) -> bool:
         """Checks if the benchmark service has passed its "prepare" status."""
         try:
-            return self.database.state.get().db_name not in self.client._admin_client.list_topics()
+            if not (model := self.database_state.model()):
+                return False
+            return model.db_name not in self.client._admin_client.list_topics()
         except Exception as e:
             logger.info(f"Error describing topic: {e}")
             return False
@@ -386,25 +402,57 @@ def is_cleaned(self) -> bool:
     @cached_property
     def client(self) -> KafkaClient:
         """Return the Kafka client."""
-        state = self.database.state.get()
+        if not (state := self.database_state.model()):
+            return KafkaClient(
+                servers=[],
+                username=None,
+                password=None,
+                security_protocol="SASL_PLAINTEXT",
+                replication_factor=1,
+            )
         return KafkaClient(
-            servers=self.database.bootstrap_servers(),
+            servers=state.hosts or [],
             username=state.username,
             password=state.password,
             security_protocol="SASL_SSL" if (state.tls or state.tls_ca) else "SASL_PLAINTEXT",
             cafile_path=state.tls_ca,
             certfile_path=state.tls,
-            replication_factor=len(self.peer.units()) + 1,
+            replication_factor=len(self.peers) + 1,
         )
 
 
+class KafkaBenchmarkActionsHandler(ActionsHandler):
+    """Handle the actions for the benchmark charm."""
+
+    def __init__(self, charm: DPBenchmarkCharmBase):
+        """Initialize the class."""
+        super().__init__(charm)
+        self.config: BenchmarkCharmConfig = charm.config
+
+    @override
+    def _preflight_checks(self) -> bool:
+        """Check if we have the necessary relations.
+
+        In the Kafka case, we need the client relation to be able to connect to the database.
+        """
+        if int(self.config.parallel_processes) < 2:
+            logger.error("The number of parallel processes must be greater than 1.")
+            self.unit.status = BlockedStatus(
+                "The number of parallel processes must be greater than 1."
+            )
+            return False
+        return super()._preflight_checks()
+
+
 class KafkaBenchmarkOperator(DPBenchmarkCharmBase):
     """Charm the service."""
 
-    def __init__(self, *args):
-        self.workload_params_template = KAFKA_WORKLOAD_PARAMS_TEMPLATE
+    config_type = KafkaBenchmarkCharmConfig
 
+    def __init__(self, *args):
         super().__init__(*args, db_relation_name=CLIENT_RELATION_NAME)
+
+        self.workload_params_template = KAFKA_WORKLOAD_PARAMS_TEMPLATE
         self.labels = ",".join([self.model.name, self.unit.name.replace("/", "-")])
 
         self.database = KafkaDatabaseRelationHandler(
@@ -414,8 +462,8 @@ def __init__(self, *args):
         self.peer_handler = KafkaPeersRelationHandler(self, PEER_RELATION)
         self.config_manager = KafkaConfigManager(
             workload=self.workload,
-            database=self.database,
-            peer=self.peer_handler,
+            database_state=self.database.state,
+            peers=self.peer_handler.peers(),
             config=self.config,
             labels=self.labels,
         )
@@ -424,41 +472,23 @@ def __init__(self, *args):
             self.peers.this_unit(),
             self.config_manager,
         )
+        self.actions = KafkaBenchmarkActionsHandler(self)
 
         self.framework.observe(self.database.on.db_config_update, self._on_config_changed)
 
     @override
-    def _on_install(self, event: EventBase) -> None:
+    def _on_install(self, _: EventBase) -> None:
         """Install the charm."""
         apt.add_package("openjdk-18-jre", update_cache=True)
 
-    @override
-    def _preflight_checks(self) -> bool:
-        """Check if we have the necessary relations.
-
-        In kafka case, we need the client relation to be able to connect to the database.
-        """
-        if self.config.get("parallel_processes") < 2:
-            logger.error("The number of parallel processes must be greater than 1.")
-            self.unit.status = BlockedStatus(
-                "The number of parallel processes must be greater than 1."
-            )
-            return False
-        return super()._preflight_checks()
-
     @override
     def _on_config_changed(self, event):
         """Handle the config changed event."""
-        if not self._preflight_checks():
+        if not self.actions._preflight_checks():
             event.defer()
             return
         return super()._on_config_changed(event)
 
-    @override
-    def supported_workloads(self) -> list[str]:
-        """List of supported workloads."""
-        return ["default"]
-
 
 if __name__ == "__main__":
     ops.main(KafkaBenchmarkOperator)
diff --git a/src/models.py b/src/models.py
new file mode 100644
index 0000000..92f1651
--- /dev/null
+++ b/src/models.py
@@ -0,0 +1,33 @@
+# Copyright 2025 Canonical Ltd.
+# See LICENSE file for licensing details.
+
+"""Structured configuration for the Kafka benchmark charm."""
+
+from pydantic import BaseModel, validator
+
+from benchmark.core.structured_config import BenchmarkCharmConfig
+
+
+class WorkloadType(BaseModel):
+    """Workload type parameters."""
+
+    message_size: int
+    producer_rate: int
+
+
+WorkloadTypeParameters = {
+    "default": WorkloadType(message_size=1024, producer_rate=100000),
+}
+
+
+class KafkaBenchmarkCharmConfig(BenchmarkCharmConfig):
+    """Manager for the structured configuration."""
+
+    @validator("workload_name")
+    @classmethod
+    def profile_values(cls, value: str) -> str:
+        """Check profile config option is valid."""
+        if value not in WorkloadTypeParameters.keys():
+            raise ValueError(f"Value not one of {str(WorkloadTypeParameters.keys())}")
+
+        return value
diff --git a/src/wrapper.py b/src/wrapper.py
index bd8ed87..e083844 100755
--- a/src/wrapper.py
+++ b/src/wrapper.py
@@ -8,8 +8,8 @@
 import os
 import re
 
-from overrides import override
 from pydantic import BaseModel
+from typing_extensions import override
 
 from benchmark.literals import BENCHMARK_WORKLOAD_PATH
 from benchmark.wrapper.core import (
@@ -137,7 +137,7 @@ def _map_run(self) -> tuple[BenchmarkManager | None, list[BenchmarkProcess] | No
         """Returns the mapping for the run phase."""
         driver_path = os.path.join(BENCHMARK_WORKLOAD_PATH, "worker_params.yaml")
         workload_path = os.path.join(BENCHMARK_WORKLOAD_PATH, "dpe_benchmark.json")
-        processes = [
+        processes: list[BenchmarkProcess] = [
             KafkaBenchmarkProcess(
                 model=ProcessModel(
                     cmd=f"""sudo bin/benchmark-worker -p {peer.split(":")[1]} -sp {int(peer.split(":")[1]) + 1}""",
diff --git a/tests/unit/test_lifecycle.py b/tests/unit/test_lifecycle.py
index 0d28336..8b5e8d1 100644
--- a/tests/unit/test_lifecycle.py
+++ b/tests/unit/test_lifecycle.py
@@ -45,6 +45,17 @@ def test_next_state_clean():
     assert lifecycle.next(DPBenchmarkLifecycleTransition.CLEAN) == DPBenchmarkLifecycleState.UNSET
 
 
+def test_next_state_stop():
+    lifecycle = lifecycle_factory(DPBenchmarkLifecycleState.STOPPED)
+    lifecycle.config_manager.is_running = MagicMock(return_value=False)
+    # With the workload not running, no transition is expected
+    assert lifecycle.next(None) is None
+
+    # Test now with the workload recovered
+    lifecycle.config_manager.is_running = MagicMock(return_value=True)
+    assert lifecycle.next(None) == DPBenchmarkLifecycleState.RUNNING
+
+
 def test_next_state_prepare():
     lifecycle = lifecycle_factory(DPBenchmarkLifecycleState.UNSET)
     assert (