Skip to content

Commit d47ea28

Browse files
authoredJan 30, 2025··
Merge pull request #2094 from Giskard-AI/feature/gsk-3948-add-numerical-perturbation-detector
[GSK-3948] Add numerical perturbation detector
2 parents 6753a58 + 9d05258 commit d47ea28

8 files changed

+433
-138
lines changed
 
+201-90
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,86 @@
11
from typing import Optional, Sequence
22

3-
from abc import abstractmethod
3+
from abc import ABC, abstractmethod
44

55
import numpy as np
66
import pandas as pd
77

88
from ...datasets.base import Dataset
99
from ...llm import LLMImportError
1010
from ...models.base import BaseModel
11+
from ...models.base.model_prediction import ModelPredictionResults
1112
from ..issues import Issue, IssueLevel, Robustness
1213
from ..logger import logger
1314
from ..registry import Detector
15+
from .base_perturbation_function import PerturbationFunction
16+
from .numerical_transformations import NumericalTransformation
1417
from .text_transformations import TextTransformation
1518

1619

17-
class BaseTextPerturbationDetector(Detector):
18-
"""Base class for metamorphic detectors based on text transformations."""
20+
def _relative_delta(actual: np.ndarray, reference: np.ndarray) -> np.ndarray:
21+
"""
22+
Computes elementwise relative delta. If reference[i] == 0, we replace it with epsilon
23+
to avoid division by zero.
24+
"""
25+
epsilon = 1e-9
26+
safe_ref = np.where(reference == 0, epsilon, reference)
27+
return (actual - reference) / safe_ref
28+
29+
30+
def _get_default_num_samples(model) -> int:
31+
if model.is_text_generation:
32+
return 10
33+
return 1_000
34+
35+
36+
def _get_default_output_sensitivity(model) -> float:
37+
if model.is_text_generation:
38+
return 0.15
39+
return 0.05
40+
41+
42+
def _get_default_threshold(model) -> float:
43+
if model.is_text_generation:
44+
return 0.10
45+
return 0.05
46+
47+
48+
def _generate_robustness_tests(issue: Issue):
49+
from ...testing.tests.metamorphic import test_metamorphic_invariance
50+
51+
# Only generates a single metamorphic test
52+
return {
53+
f"Invariance to “{issue.transformation_fn}”": test_metamorphic_invariance(
54+
transformation_function=issue.transformation_fn,
55+
slicing_function=None,
56+
threshold=1 - issue.meta["threshold"],
57+
output_sensitivity=issue.meta.get("output_sentitivity", None),
58+
)
59+
}
60+
61+
62+
class BasePerturbationDetector(Detector, ABC):
63+
"""
64+
Common parent class for metamorphic perturbation detectors (both text and numerical).
65+
"""
1966

2067
_issue_group = Robustness
2168
_taxonomy = ["avid-effect:performance:P0201"]
2269

2370
def __init__(
2471
self,
25-
transformations: Optional[Sequence[TextTransformation]] = None,
72+
transformations: Optional[Sequence[PerturbationFunction]] = None,
2673
threshold: Optional[float] = None,
27-
output_sensitivity=None,
74+
output_sensitivity: Optional[float] = None,
2875
num_samples: Optional[int] = None,
2976
):
30-
"""Creates a new instance of the detector.
77+
"""
78+
Creates a new instance of the detector.
3179
3280
Parameters
3381
----------
34-
transformations: Optional[Sequence[TextTransformation]]
35-
The text transformations used in the metamorphic testing. See :ref:`transformation_functions` for details
82+
transformations: Optional[Sequence[PerturbationFunction]]
83+
The transformations used in the metamorphic testing. See :ref:`transformation_functions` for details
3684
about the available transformations. If not provided, a default set of transformations will be used.
3785
threshold: Optional[float]
3886
The threshold for the fail rate, which is defined as the proportion of samples for which the model
@@ -52,53 +100,103 @@ def __init__(
52100
self.num_samples = num_samples
53101
self.output_sensitivity = output_sensitivity
54102

55-
def run(self, model: BaseModel, dataset: Dataset, features: Sequence[str]) -> Sequence[Issue]:
56-
transformations = self.transformations or self._get_default_transformations(model, dataset)
103+
@abstractmethod
104+
def _select_features(self, dataset: Dataset, features: Sequence[str]) -> Sequence[str]:
105+
raise NotImplementedError
57106

58-
# Only analyze text features
59-
text_features = [
60-
f
61-
for f in features
62-
if dataset.column_types[f] == "text" and pd.api.types.is_string_dtype(dataset.df[f].dtype)
63-
]
107+
@abstractmethod
108+
def _get_default_transformations(self) -> Sequence[PerturbationFunction]:
109+
raise NotImplementedError
64110

65-
logger.info(
66-
f"{self.__class__.__name__}: Running with transformations={[t.name for t in transformations]} "
67-
f"threshold={self.threshold} output_sensitivity={self.output_sensitivity} num_samples={self.num_samples}"
68-
)
111+
@abstractmethod
112+
def _supports_text_generation(self) -> bool:
113+
raise NotImplementedError
69114

70-
issues = []
71-
for transformation in transformations:
72-
issues.extend(self._detect_issues(model, dataset, transformation, text_features))
115+
def _compute_passed(
116+
self,
117+
model: BaseModel,
118+
original_pred: ModelPredictionResults,
119+
perturbed_pred: ModelPredictionResults,
120+
output_sensitivity: float,
121+
) -> np.ndarray:
122+
if model.is_classification:
123+
return original_pred.raw_prediction == perturbed_pred.raw_prediction
124+
125+
elif model.is_regression:
126+
rel_delta = _relative_delta(perturbed_pred.raw_prediction, original_pred.raw_prediction)
127+
return np.abs(rel_delta) < output_sensitivity
128+
129+
elif model.is_text_generation:
130+
if not self._supports_text_generation():
131+
raise NotImplementedError("Text generation is not supported by this detector.")
132+
try:
133+
import evaluate
134+
except ImportError as err:
135+
raise LLMImportError() from err
136+
137+
scorer = evaluate.load("bertscore")
138+
score = scorer.compute(
139+
predictions=perturbed_pred.prediction,
140+
references=original_pred.prediction,
141+
model_type="distilbert-base-multilingual-cased",
142+
idf=True,
143+
)
144+
return np.array(score["f1"]) > 1 - output_sensitivity
73145

74-
return [i for i in issues if i is not None]
146+
else:
147+
raise NotImplementedError("Only classification, regression, or text generation models are supported.")
75148

76-
@abstractmethod
77-
def _get_default_transformations(self, model: BaseModel, dataset: Dataset) -> Sequence[TextTransformation]:
78-
...
149+
def _create_examples(
150+
self,
151+
original_data: Dataset,
152+
original_pred: ModelPredictionResults,
153+
perturbed_data: Dataset,
154+
perturbed_pred: ModelPredictionResults,
155+
feature: str,
156+
passed: np.ndarray,
157+
model: BaseModel,
158+
transformation_fn,
159+
) -> pd.DataFrame:
160+
examples = original_data.df.loc[~passed, [feature]].copy()
161+
examples[f"{transformation_fn.name}({feature})"] = perturbed_data.df.loc[~passed, feature]
162+
163+
examples["Original prediction"] = original_pred.prediction[~passed]
164+
examples["Prediction after perturbation"] = perturbed_pred.prediction[~passed]
165+
166+
if model.is_classification:
167+
examples["Original prediction"] = examples["Original prediction"].astype(str)
168+
examples["Prediction after perturbation"] = examples["Prediction after perturbation"].astype(str)
169+
170+
ps_before = pd.Series(original_pred.probabilities[~passed], index=examples.index)
171+
ps_after = pd.Series(perturbed_pred.probabilities[~passed], index=examples.index)
172+
173+
examples["Original prediction"] += ps_before.apply(lambda p: f" (p={p:.2f})")
174+
examples["Prediction after perturbation"] += ps_after.apply(lambda p: f" (p={p:.2f})")
175+
176+
return examples
79177

80178
def _detect_issues(
81179
self,
82180
model: BaseModel,
83181
dataset: Dataset,
84-
transformation: TextTransformation,
182+
transformation,
85183
features: Sequence[str],
86184
) -> Sequence[Issue]:
185+
# Fall back to defaults if not explicitly set
87186
num_samples = self.num_samples if self.num_samples is not None else _get_default_num_samples(model)
187+
threshold = self.threshold if self.threshold is not None else _get_default_threshold(model)
88188
output_sensitivity = (
89189
self.output_sensitivity if self.output_sensitivity is not None else _get_default_output_sensitivity(model)
90190
)
91-
threshold = self.threshold if self.threshold is not None else _get_default_threshold(model)
92191

93192
issues = []
94-
# @TODO: integrate this with Giskard metamorphic tests already present
95193
for feature in features:
194+
# Build transformation function for this feature
96195
transformation_fn = transformation(column=feature)
97196
transformed = dataset.transform(transformation_fn)
98197

99198
# Select only the records which were changed
100199
changed_idx = dataset.df.index[transformed.df[feature] != dataset.df[feature]]
101-
102200
if changed_idx.empty:
103201
continue
104202

@@ -107,6 +205,7 @@ def _detect_issues(
107205
rng = np.random.default_rng(747)
108206
changed_idx = changed_idx[rng.choice(len(changed_idx), num_samples, replace=False)]
109207

208+
# Build original vs. perturbed datasets
110209
original_data = Dataset(
111210
dataset.df.loc[changed_idx],
112211
target=dataset.target,
@@ -124,27 +223,12 @@ def _detect_issues(
124223
original_pred = model.predict(original_data)
125224
perturbed_pred = model.predict(perturbed_data)
126225

127-
if model.is_classification:
128-
passed = original_pred.raw_prediction == perturbed_pred.raw_prediction
129-
elif model.is_regression:
130-
rel_delta = _relative_delta(perturbed_pred.raw_prediction, original_pred.raw_prediction)
131-
passed = np.abs(rel_delta) < output_sensitivity
132-
elif model.is_text_generation:
133-
try:
134-
import evaluate
135-
except ImportError as err:
136-
raise LLMImportError() from err
137-
138-
scorer = evaluate.load("bertscore")
139-
score = scorer.compute(
140-
predictions=perturbed_pred.prediction,
141-
references=original_pred.prediction,
142-
model_type="distilbert-base-multilingual-cased",
143-
idf=True,
144-
)
145-
passed = np.array(score["f1"]) > 1 - output_sensitivity
146-
else:
147-
raise NotImplementedError("Only classification, regression, or text generation models are supported.")
226+
passed = self._compute_passed(
227+
model=model,
228+
original_pred=original_pred,
229+
perturbed_pred=perturbed_pred,
230+
output_sensitivity=output_sensitivity,
231+
)
148232

149233
pass_rate = passed.mean()
150234
fail_rate = 1 - pass_rate
@@ -196,61 +280,88 @@ def _detect_issues(
196280
)
197281

198282
# Add examples
199-
examples = original_data.df.loc[~passed, (feature,)].copy()
200-
examples[f"{transformation_fn.name}({feature})"] = perturbed_data.df.loc[~passed, feature]
201-
202-
examples["Original prediction"] = original_pred.prediction[~passed]
203-
examples["Prediction after perturbation"] = perturbed_pred.prediction[~passed]
204-
205-
if model.is_classification:
206-
examples["Original prediction"] = examples["Original prediction"].astype(str)
207-
examples["Prediction after perturbation"] = examples["Prediction after perturbation"].astype(str)
208-
ps_before = pd.Series(original_pred.probabilities[~passed], index=examples.index)
209-
ps_after = pd.Series(perturbed_pred.probabilities[~passed], index=examples.index)
210-
examples["Original prediction"] += ps_before.apply(lambda p: f" (p = {p:.2f})")
211-
examples["Prediction after perturbation"] += ps_after.apply(lambda p: f" (p = {p:.2f})")
212-
283+
examples = self._create_examples(
284+
original_data,
285+
original_pred,
286+
perturbed_data,
287+
perturbed_pred,
288+
feature,
289+
passed,
290+
model,
291+
transformation_fn,
292+
)
213293
issue.add_examples(examples)
214294

215295
issues.append(issue)
216296

217297
return issues
218298

299+
def run(self, model: BaseModel, dataset: Dataset, features: Sequence[str]) -> Sequence[Issue]:
300+
"""
301+
Runs the perturbation detector on the given model and dataset.
219302
220-
def _generate_robustness_tests(issue: Issue):
221-
from ...testing.tests.metamorphic import test_metamorphic_invariance
303+
Parameters
304+
----------
305+
model: BaseModel
306+
The model to test.
307+
dataset: Dataset
308+
The dataset to use for testing.
309+
features: Sequence[str]
310+
The features (columns) to test.
311+
312+
Returns
313+
-------
314+
Sequence[Issue]
315+
A list of issues found during the testing.
316+
"""
317+
transformations = self.transformations or self._get_default_transformations()
318+
selected_features = self._select_features(dataset, features)
222319

223-
# Only generates a single metamorphic test
224-
return {
225-
f"Invariance to “{issue.transformation_fn}”": test_metamorphic_invariance(
226-
transformation_function=issue.transformation_fn,
227-
slicing_function=None,
228-
threshold=1 - issue.meta["threshold"],
229-
output_sensitivity=issue.meta["output_sentitivity"],
320+
logger.info(
321+
f"{self.__class__.__name__}: Running with transformations={[t.name for t in transformations]} "
322+
f"threshold={self.threshold} output_sensitivity={self.output_sensitivity} num_samples={self.num_samples}"
230323
)
231-
}
232324

325+
issues = []
326+
for transformation in transformations:
327+
issues.extend(self._detect_issues(model, dataset, transformation, selected_features))
233328

234-
def _relative_delta(actual, reference):
235-
return (actual - reference) / reference
329+
return [i for i in issues if i is not None]
236330

237331

238-
def _get_default_num_samples(model) -> int:
239-
if model.is_text_generation:
240-
return 10
332+
class BaseTextPerturbationDetector(BasePerturbationDetector):
333+
"""
334+
Base class for metamorphic detectors based on text transformations.
335+
"""
241336

242-
return 1_000
337+
def _select_features(self, dataset: Dataset, features: Sequence[str]) -> Sequence[str]:
338+
# Only analyze text features
339+
return [
340+
f
341+
for f in features
342+
if dataset.column_types[f] == "text" and pd.api.types.is_string_dtype(dataset.df[f].dtype)
343+
]
243344

345+
@abstractmethod
346+
def _get_default_transformations(self) -> Sequence[TextTransformation]:
347+
raise NotImplementedError
244348

245-
def _get_default_output_sensitivity(model) -> float:
246-
if model.is_text_generation:
247-
return 0.15
349+
def _supports_text_generation(self) -> bool:
350+
return True
248351

249-
return 0.05
250352

353+
class BaseNumericalPerturbationDetector(BasePerturbationDetector):
354+
"""
355+
Base class for metamorphic detectors based on numerical feature perturbations.
356+
"""
251357

252-
def _get_default_threshold(model) -> float:
253-
if model.is_text_generation:
254-
return 0.10
358+
def _select_features(self, dataset: Dataset, features: Sequence[str]) -> Sequence[str]:
359+
# Only analyze numeric features
360+
return [f for f in features if dataset.column_types[f] == "numeric"]
255361

256-
return 0.05
362+
@abstractmethod
363+
def _get_default_transformations(self) -> Sequence[NumericalTransformation]:
364+
raise NotImplementedError
365+
366+
def _supports_text_generation(self) -> bool:
367+
return False
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from typing import Any
2+
3+
import pandas as pd
4+
5+
from ...core.core import DatasetProcessFunctionMeta
6+
from ...registry.registry import get_object_uuid
7+
from ...registry.transformation_function import TransformationFunction
8+
9+
10+
class PerturbationFunction(TransformationFunction):
11+
name: str
12+
13+
def __init__(self, column: str, needs_dataset: bool = False) -> None:
14+
super().__init__(None, row_level=False, cell_level=False, needs_dataset=needs_dataset)
15+
self.column = column
16+
self.meta = DatasetProcessFunctionMeta(type="TRANSFORMATION")
17+
self.meta.uuid = get_object_uuid(self)
18+
self.meta.code = self.name
19+
self.meta.name = self.name
20+
self.meta.display_name = self.name
21+
self.meta.tags = ["pickle", "scan"]
22+
self.meta.doc = self.meta.default_doc("Automatically generated transformation function")
23+
24+
def __str__(self) -> str:
25+
return self.name
26+
27+
def make_perturbation(self, data_or_series: Any) -> Any:
28+
raise NotImplementedError()
29+
30+
def execute(self, data: pd.DataFrame) -> pd.DataFrame:
31+
raise NotImplementedError()

‎giskard/scanner/robustness/ethical_bias_detector.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Sequence
22

3-
from ...datasets.base import Dataset
4-
from ...models.base import BaseModel
53
from ..decorators import detector
64
from ..issues import Ethical
75
from .base_detector import BaseTextPerturbationDetector
@@ -28,7 +26,7 @@ class EthicalBiasDetector(BaseTextPerturbationDetector):
2826
_issue_group = Ethical
2927
_taxonomy = ["avid-effect:ethics:E0101", "avid-effect:performance:P0201"]
3028

31-
def _get_default_transformations(self, model: BaseModel, dataset: Dataset) -> Sequence[TextTransformation]:
29+
def _get_default_transformations(self) -> Sequence[TextTransformation]:
3230
from .text_transformations import (
3331
TextGenderTransformation,
3432
TextNationalityTransformation,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from typing import Sequence
2+
3+
from ..decorators import detector
4+
from .base_detector import BaseNumericalPerturbationDetector
5+
from .numerical_transformations import NumericalTransformation
6+
7+
8+
class BoundClassWrapper:
9+
def __init__(self, cls, **bound_kwargs):
10+
self.cls = cls
11+
self.bound_kwargs = bound_kwargs
12+
13+
def __call__(self, *args, **kwargs):
14+
return self.cls(*args, **self.bound_kwargs, **kwargs)
15+
16+
def __getattr__(self, attr):
17+
# Forward attribute access to the wrapped class
18+
return getattr(self.cls, attr)
19+
20+
21+
@detector(
22+
name="numerical_perturbation",
23+
tags=[
24+
"numerical_perturbation",
25+
"robustness",
26+
"classification",
27+
"regression",
28+
],
29+
)
30+
class NumericalPerturbationDetector(BaseNumericalPerturbationDetector):
31+
"""Detects robustness problems in a model by applying numerical perturbations to the numerical features."""
32+
33+
def _get_default_transformations(self) -> Sequence[NumericalTransformation]:
34+
from .numerical_transformations import AddGaussianNoise, MultiplyByFactor
35+
36+
return [
37+
BoundClassWrapper(MultiplyByFactor, factor=1.01),
38+
BoundClassWrapper(MultiplyByFactor, factor=0.99),
39+
BoundClassWrapper(AddGaussianNoise, mean=0, std=0.01),
40+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import numpy as np
2+
import pandas as pd
3+
4+
from .base_perturbation_function import PerturbationFunction
5+
6+
7+
class NumericalTransformation(PerturbationFunction):
8+
def __init__(self, column: str, needs_dataset: bool = False) -> None:
9+
super().__init__(column, needs_dataset=needs_dataset)
10+
11+
def execute(self, data: pd.DataFrame) -> pd.DataFrame:
12+
feature_data = data[self.column].dropna()
13+
data.loc[feature_data.index, self.column] = self.make_perturbation(feature_data)
14+
return data
15+
16+
17+
class MultiplyByFactor(NumericalTransformation):
18+
name = "Multiply by factor"
19+
20+
def __init__(self, column: str, factor: float) -> None:
21+
super().__init__(column)
22+
self.factor = factor
23+
24+
def make_perturbation(self, values: pd.Series) -> pd.Series:
25+
# Round if the column is an integer type
26+
if np.issubdtype(values.dtype, np.integer):
27+
return np.round(values * self.factor).astype(values.dtype)
28+
return values * self.factor
29+
30+
31+
class AddGaussianNoise(NumericalTransformation):
32+
name = "Add Gaussian noise"
33+
34+
def __init__(self, column: str, mean: float = 0, std: float = 0.01, rng_seed: int = 1729) -> None:
35+
super().__init__(column)
36+
self.mean = mean
37+
self.std = std
38+
self.rng = np.random.default_rng(seed=rng_seed)
39+
40+
def make_perturbation(self, values: pd.Series) -> pd.Series:
41+
noise = self.rng.normal(self.mean, self.std, values.shape)
42+
if np.issubdtype(values.dtype, np.integer):
43+
return np.round(values + noise).astype(values.dtype)
44+
return values + noise

‎giskard/scanner/robustness/text_perturbation_detector.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from typing import Sequence
22

3-
from ...datasets.base import Dataset
4-
from ...models.base import BaseModel
53
from ..decorators import detector
64
from .base_detector import BaseTextPerturbationDetector
75
from .text_transformations import TextTransformation
@@ -25,7 +23,7 @@ class TextPerturbationDetector(BaseTextPerturbationDetector):
2523
e.g. transforming to uppercase, lowercase, or title case, or by introducing typos.
2624
"""
2725

28-
def _get_default_transformations(self, model: BaseModel, dataset: Dataset) -> Sequence[TextTransformation]:
26+
def _get_default_transformations(self) -> Sequence[TextTransformation]:
2927
from .text_transformations import (
3028
TextAccentRemovalTransformation,
3129
TextLowercase,

‎giskard/scanner/robustness/text_transformations.py

+26-42
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Any, Optional
2+
13
import itertools
24
import json
35
import re
@@ -8,38 +10,20 @@
810
import pandas as pd
911
from num2words import num2words
1012

11-
from ...core.core import DatasetProcessFunctionMeta
1213
from ...datasets import Dataset
1314
from ...functions.transformation import gruber
14-
from ...registry.registry import get_object_uuid
15-
from ...registry.transformation_function import TransformationFunction
16-
15+
from .base_perturbation_function import PerturbationFunction
1716

18-
class TextTransformation(TransformationFunction):
19-
name: str
2017

21-
def __init__(self, column, needs_dataset=False):
22-
super().__init__(None, row_level=False, cell_level=False, needs_dataset=needs_dataset)
23-
self.column = column
24-
self.meta = DatasetProcessFunctionMeta(type="TRANSFORMATION")
25-
self.meta.uuid = get_object_uuid(self)
26-
self.meta.code = self.name
27-
self.meta.name = self.name
28-
self.meta.display_name = self.name
29-
self.meta.tags = ["pickle", "scan"]
30-
self.meta.doc = self.meta.default_doc("Automatically generated transformation function")
31-
32-
def __str__(self):
33-
return self.name
18+
class TextTransformation(PerturbationFunction):
19+
def __init__(self, column: str, needs_dataset: bool = False) -> None:
20+
super().__init__(column, needs_dataset=needs_dataset)
3421

3522
def execute(self, data: pd.DataFrame) -> pd.DataFrame:
3623
feature_data = data[self.column].dropna().astype(str)
3724
data.loc[feature_data.index, self.column] = feature_data.apply(self.make_perturbation)
3825
return data
3926

40-
def make_perturbation(self, text: str) -> str:
41-
raise NotImplementedError()
42-
4327

4428
class TextUppercase(TextTransformation):
4529
name = "Transform to uppercase"
@@ -71,7 +55,7 @@ def execute(self, data: pd.DataFrame) -> pd.DataFrame:
7155
class TextTypoTransformation(TextTransformation):
7256
name = "Add typos"
7357

74-
def __init__(self, column, rate=0.05, min_length=10, rng_seed=1729):
58+
def __init__(self, column: str, rate: float = 0.05, min_length: int = 10, rng_seed: int = 1729):
7559
super().__init__(column)
7660
from .entity_swap import typos
7761

@@ -80,7 +64,7 @@ def __init__(self, column, rate=0.05, min_length=10, rng_seed=1729):
8064
self._key_typos = typos
8165
self.rng = np.random.default_rng(seed=rng_seed)
8266

83-
def make_perturbation(self, x):
67+
def make_perturbation(self, x: str) -> str:
8468
# Skip if the text is too short
8569
if len(x) < self.min_length:
8670
return x
@@ -118,7 +102,7 @@ def make_perturbation(self, x):
118102
x = x[:i] + x[i + 1] + x[i] + x[i + 2 :]
119103
return x
120104

121-
def _random_key_typo(self, char):
105+
def _random_key_typo(self, char: str):
122106
if char.lower() in self._key_typos:
123107
typo = self.rng.choice(self._key_typos[char.lower()])
124108
return typo if char.islower() else typo.upper()
@@ -128,7 +112,7 @@ def _random_key_typo(self, char):
128112
class TextFromOCRTypoTransformation(TextTransformation):
129113
name = "Add typos from OCR"
130114

131-
def __init__(self, column, rate=0.05, min_length=10, rng_seed=1729):
115+
def __init__(self, column: str, rate: float = 0.05, min_length: int = 10, rng_seed: int = 1729):
132116
super().__init__(column)
133117
from .entity_swap import ocr_typos
134118

@@ -137,7 +121,7 @@ def __init__(self, column, rate=0.05, min_length=10, rng_seed=1729):
137121
self._ocr_typos = ocr_typos
138122
self.rng = np.random.default_rng(seed=rng_seed)
139123

140-
def make_perturbation(self, x):
124+
def make_perturbation(self, x: str) -> str:
141125
# Check if the input is None
142126
if x is None:
143127
return None
@@ -165,7 +149,7 @@ def make_perturbation(self, x):
165149
x = x[:i] + x[i + 1 :]
166150
return x
167151

168-
def _random_ocr_typo(self, char):
152+
def _random_ocr_typo(self, char: str) -> str:
169153
if char.lower() in self._ocr_typos:
170154
typo = self.rng.choice(self._ocr_typos[char.lower()])
171155
return typo if char.islower() else typo.upper()
@@ -182,7 +166,7 @@ def __init__(self, *args, **kwargs):
182166
self._trans_table = str.maketrans("", "", self._punctuation)
183167
self._regex = re.compile(rf"\b[{re.escape(self._punctuation)}]+\b")
184168

185-
def make_perturbation(self, text):
169+
def make_perturbation(self, text: str) -> str:
186170
# Split URLs so that they are not affected by the transformation
187171
pieces = gruber.split(text)
188172

@@ -198,12 +182,12 @@ def make_perturbation(self, text):
198182
class TextAccentRemovalTransformation(TextTransformation):
199183
name = "Accent Removal"
200184

201-
def __init__(self, column, rate=1.0, rng_seed=1729):
185+
def __init__(self, column: str, rate: float = 1.0, rng_seed: int = 1729):
202186
super().__init__(column)
203187
self.rate = rate
204188
self.rng = np.random.default_rng(seed=rng_seed)
205189

206-
def make_perturbation(self, text):
190+
def make_perturbation(self, text: str) -> str:
207191
return "".join(
208192
char
209193
for char in unicodedata.normalize("NFD", text)
@@ -212,7 +196,7 @@ def make_perturbation(self, text):
212196

213197

214198
class TextLanguageBasedTransformation(TextTransformation):
215-
def __init__(self, column, rng_seed=1729):
199+
def __init__(self, column: str, rng_seed: int = 1729):
216200
super().__init__(column, needs_dataset=True)
217201
self._lang_dictionary = dict()
218202
self._load_dictionaries()
@@ -228,13 +212,13 @@ def execute(self, dataset: Dataset) -> pd.DataFrame:
228212
dataset.df.loc[feature_data.index, self.column] = feature_data.apply(self.make_perturbation, axis=1)
229213
return dataset.df
230214

231-
def make_perturbation(self, row):
215+
def make_perturbation(self, row: pd.Series) -> Any:
232216
raise NotImplementedError()
233217

234-
def _switch(self, word, language):
218+
def _switch(self, word: str, language: str) -> Optional[tuple[str, str]]:
235219
raise NotImplementedError()
236220

237-
def _select_dict(self, language):
221+
def _select_dict(self, language: str) -> Optional[Any]:
238222
try:
239223
return self._lang_dictionary[language]
240224
except KeyError:
@@ -249,7 +233,7 @@ def _load_dictionaries(self):
249233

250234
self._lang_dictionary = {"en": gender_switch_en, "fr": gender_switch_fr}
251235

252-
def make_perturbation(self, row):
236+
def make_perturbation(self, row: pd.Series) -> str:
253237
text = row[self.column]
254238
language = row["language__gsk__meta"]
255239

@@ -265,7 +249,7 @@ def make_perturbation(self, row):
265249

266250
return new_text
267251

268-
def _switch(self, word, language):
252+
def _switch(self, word: str, language: str) -> Optional[tuple[str, str]]:
269253
try:
270254
return (word, self._lang_dictionary[language][word.lower()])
271255
except KeyError:
@@ -279,7 +263,7 @@ def _load_dictionaries(self):
279263
# Regex to match numbers in text
280264
self._regex = re.compile(r"(?<!\d/)(?<!\d\.)\b\d+(?:\.\d+)?\b(?!(?:\.\d+)?@|\d?/?\d)")
281265

282-
def make_perturbation(self, row):
266+
def make_perturbation(self, row: pd.Series) -> str:
283267
# Replace numbers with words
284268
value = row[self.column]
285269
if pd.isna(value):
@@ -307,7 +291,7 @@ def _load_dictionaries(self):
307291

308292
self._lang_dictionary = {"en": religion_dict_en, "fr": religion_dict_fr}
309293

310-
def make_perturbation(self, row):
294+
def make_perturbation(self, row: pd.Series) -> str:
311295
# Get text
312296
text = row[self.column]
313297

@@ -345,7 +329,7 @@ def _load_dictionaries(self):
345329
nationalities_dict = json.load(f)
346330
self._lang_dictionary = {"en": nationalities_dict["en"], "fr": nationalities_dict["fr"]}
347331

348-
def make_perturbation(self, row):
332+
def make_perturbation(self, row: pd.Series) -> str:
349333
text = row[self.column]
350334
language = row["language__gsk__meta"]
351335
nationalities_word_dict = self._select_dict(language)
@@ -381,7 +365,7 @@ def make_perturbation(self, row):
381365
class TextFromSpeechTypoTransformation(TextLanguageBasedTransformation):
382366
name = "Add text from speech typos"
383367

384-
def __init__(self, column, rng_seed=1729, min_length=10):
368+
def __init__(self, column: str, rng_seed: int = 1729, min_length: int = 10):
385369
super().__init__(column, rng_seed=rng_seed)
386370

387371
self.min_length = min_length
@@ -391,7 +375,7 @@ def _load_dictionaries(self):
391375

392376
self._word_typos = speech_typos
393377

394-
def make_perturbation(self, row):
378+
def make_perturbation(self, row: pd.Series) -> str:
395379
text = row[self.column]
396380
language = row["language__gsk__meta"]
397381

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
import giskard
6+
from giskard.scanner.robustness.numerical_perturbation_detector import NumericalPerturbationDetector
7+
8+
9+
class MockClassificationModel:
10+
def predict(self, df):
11+
# Randomly assign predictions, introducing some variability
12+
return np.random.choice([0, 1], size=len(df))
13+
14+
15+
class MockRegressionModel:
16+
def predict(self, df):
17+
# For simplicity, use a linear relationship plus some noise
18+
return 2 * df["feature_1"] + 3 * df["feature_2"] + np.random.normal(0, 5, len(df))
19+
20+
21+
def test_numerical_perturbation_classification():
22+
# Creating a simple mock classification dataset
23+
df = pd.DataFrame(
24+
{"feature_1": [1.0, 2.0, 3.0, 4.0, 5.0], "feature_2": [10.0, 20.0, 30.0, 40.0, 50.0], "target": [0, 1, 1, 0, 0]}
25+
)
26+
dataset = giskard.Dataset(df=df, target="target", column_types={"feature_1": "numeric", "feature_2": "numeric"})
27+
28+
# Creating a mock model with some variability in predictions
29+
model = giskard.Model(MockClassificationModel().predict, model_type="classification", classification_labels=[0, 1])
30+
31+
# Running the Numerical Perturbation Detector
32+
analyzer = NumericalPerturbationDetector(threshold=0.01)
33+
issues = analyzer.run(model, dataset, features=["feature_1", "feature_2"])
34+
35+
assert issues # Ensure that the detector identifies some issues
36+
37+
38+
def test_numerical_perturbation_skips_non_numerical_dtypes():
39+
# Mock dataset with a text feature, but declared as numeric
40+
df = pd.DataFrame({"feature": ["a", "b", "c", "d", "e"], "target": [0, 1, 0, 1, 0]})
41+
dataset = giskard.Dataset(df, target="target", column_types={"feature": "text"})
42+
43+
# Creating a mock model that always predicts 1
44+
model = giskard.Model(lambda df: np.ones(len(df)), model_type="classification", classification_labels=[0, 1])
45+
46+
# Running the Numerical Perturbation Detector
47+
analyzer = NumericalPerturbationDetector(threshold=0.001, output_sensitivity=1.0, num_samples=100)
48+
issues = analyzer.run(model, dataset, features=["feature"])
49+
50+
assert not issues # Since the feature is non-numeric, no issues should be detected
51+
52+
53+
def test_numerical_perturbation_works_with_nan_values():
54+
# Mock dataset with NaN values in numeric feature
55+
df = pd.DataFrame({"feature": [1.0, 2.0, np.nan, 4.0, 5.0], "target": [0, 1, 0, 1, 0]})
56+
dataset = giskard.Dataset(df, target="target", column_types={"feature": "numeric"})
57+
58+
# Creating a mock model with some variability in predictions
59+
model = giskard.Model(
60+
lambda df: np.random.choice([0, 1], size=len(df)), model_type="classification", classification_labels=[0, 1]
61+
)
62+
63+
# Running the Numerical Perturbation Detector
64+
analyzer = NumericalPerturbationDetector(threshold=0.01)
65+
issues = analyzer.run(model, dataset, features=["feature"])
66+
67+
assert issues # Ensure that the detector identifies some issues
68+
69+
70+
@pytest.mark.memory_expensive
71+
def test_numerical_perturbation_on_regression():
72+
# Mock regression dataset
73+
df = pd.DataFrame(
74+
{
75+
"feature_1": [1.0, 2.0, 3.0, 4.0, 5.0],
76+
"feature_2": [10.0, 20.0, 30.0, 40.0, 50.0],
77+
"target": [15.0, 25.0, 35.0, 45.0, 55.0],
78+
}
79+
)
80+
dataset = giskard.Dataset(df, target="target", column_types={"feature_1": "numeric", "feature_2": "numeric"})
81+
82+
# Creating a mock model with a linear relationship
83+
model = giskard.Model(MockRegressionModel().predict, model_type="regression")
84+
85+
# Running the Numerical Perturbation Detector
86+
analyzer = NumericalPerturbationDetector(threshold=0.01, output_sensitivity=0.1)
87+
issues = analyzer.run(model, dataset, features=["feature_1", "feature_2"])
88+
89+
assert issues # Ensure that the detector identifies some issues

0 commit comments

Comments
 (0)
Please sign in to comment.