Skip to content

Commit 6cdde2f

Browse files
authored
Merge pull request #2039 from Giskard-AI/feature/gsk-3827-load-scan-test-suite-doesnt-work
[GSK-3827] Fix load/save giskard Dataset
2 parents dcf36fb + b5561dd commit 6cdde2f

File tree

2 files changed

+76
-13
lines changed

2 files changed

+76
-13
lines changed

giskard/datasets/base/__init__.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
if TYPE_CHECKING:
3535
from mlflow import MlflowClient
3636

37-
SAMPLE_SIZE = 1000
38-
3937
logger = logging.getLogger(__name__)
4038

4139

@@ -526,10 +524,22 @@ def cast_column_to_dtypes(df, column_dtypes):
526524

527525
@classmethod
528526
def load(cls, local_path: str):
529-
with open(local_path, "rb") as ds_stream:
530-
return pd.read_csv(
531-
ZstdDecompressor().stream_reader(ds_stream), keep_default_na=False, na_values=["_GSK_NA_"]
532-
)
527+
# load metadata
528+
with open(Path(local_path) / "giskard-dataset-meta.yaml", "r") as meta_f:
529+
meta = yaml.safe_load(meta_f)
530+
531+
# load data
532+
with open(Path(local_path) / "data.csv.zst", "rb") as ds_stream:
533+
df = pd.read_csv(ZstdDecompressor().stream_reader(ds_stream), keep_default_na=False, na_values=["_GSK_NA_"])
534+
535+
return cls(
536+
df,
537+
name=meta.get("name"),
538+
target=meta.get("target"),
539+
cat_columns=[k for k in meta["category_features"].keys()],
540+
column_types=meta.get("column_types"),
541+
original_id=meta.get("id"),
542+
)
533543

534544
@staticmethod
535545
def _cat_columns(meta):
@@ -543,21 +553,17 @@ def _cat_columns(meta):
543553
def cat_columns(self):
544554
return self._cat_columns(self.meta)
545555

546-
def save(self, local_path: Path, dataset_id):
547-
with open(local_path / "data.csv.zst", "wb") as f, open(local_path / "data.sample.csv.zst", "wb") as f_sample:
556+
def save(self, local_path: str):
557+
with (open(Path(local_path) / "data.csv.zst", "wb") as f,):
548558
uncompressed_bytes = save_df(self.df)
549559
compressed_bytes = compress(uncompressed_bytes)
550560
f.write(compressed_bytes)
551561
original_size_bytes, compressed_size_bytes = len(uncompressed_bytes), len(compressed_bytes)
552562

553-
uncompressed_bytes = save_df(self.df.sample(min(SAMPLE_SIZE, len(self.df.index))))
554-
compressed_bytes = compress(uncompressed_bytes)
555-
f_sample.write(compressed_bytes)
556-
557563
with open(Path(local_path) / "giskard-dataset-meta.yaml", "w") as meta_f:
558564
yaml.dump(
559565
{
560-
"id": dataset_id,
566+
"id": str(self.id),
561567
"name": self.meta.name,
562568
"target": self.meta.target,
563569
"column_types": self.meta.column_types,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import tempfile
2+
3+
import pandas as pd
4+
import pytest
5+
6+
from giskard.datasets import Dataset
7+
8+
9+
@pytest.mark.parametrize(
10+
"dataset",
11+
[
12+
Dataset(
13+
pd.DataFrame(
14+
{
15+
"question": [
16+
"What is the capital of France?",
17+
"What is the capital of Germany?",
18+
]
19+
}
20+
),
21+
column_types={"question": "text"},
22+
target=None,
23+
),
24+
Dataset(
25+
pd.DataFrame(
26+
{
27+
"country": ["France", "Germany", "France", "Germany", "France"],
28+
"capital": ["Paris", "Berlin", "Paris", "Berlin", "Paris"],
29+
}
30+
),
31+
column_types={"country": "category", "capital": "category"},
32+
cat_columns=["country", "capital"],
33+
target=None,
34+
),
35+
Dataset(
36+
pd.DataFrame(
37+
{
38+
"x": [1, 2, 3, 4, 5],
39+
"y": [2, 4, 6, 8, 10],
40+
}
41+
),
42+
column_types={"x": "numeric", "y": "numeric"},
43+
target="y",
44+
),
45+
],
46+
ids=["text", "category", "numeric"],
47+
)
48+
def test_save_and_load_dataset(dataset: Dataset):
49+
with tempfile.TemporaryDirectory() as tmp_test_folder:
50+
dataset.save(tmp_test_folder)
51+
52+
loaded_dataset = Dataset.load(tmp_test_folder)
53+
54+
assert loaded_dataset.id != dataset.id
55+
assert loaded_dataset.original_id == dataset.id
56+
assert pd.DataFrame.equals(loaded_dataset.df, dataset.df)
57+
assert loaded_dataset.meta == dataset.meta

0 commit comments

Comments
 (0)