|
66 | 66 | "source": [
|
67 | 67 | "import os\n",
|
68 | 68 | "import string\n",
|
| 69 | + "import tarfile\n", |
69 | 70 | "from pathlib import Path\n",
|
| 71 | + "from typing import Tuple, Callable\n", |
70 | 72 | "from urllib.request import urlretrieve\n",
|
71 | 73 | "\n",
|
72 | 74 | "import numpy as np\n",
|
|
79 | 81 | "from nltk.corpus import stopwords\n",
|
80 | 82 | "from sklearn.metrics import accuracy_score\n",
|
81 | 83 | "from sklearn.model_selection import train_test_split\n",
|
82 |
| - "from typing import Tuple, Callable\n", |
83 | 84 | "\n",
|
84 | 85 | "from giskard import Dataset, Model, scan, testing"
|
85 | 86 | ]
|
|
142 | 143 | "RANDOM_SEED = 0\n",
|
143 | 144 | "\n",
|
144 | 145 | "# Paths.\n",
|
145 |
| - "DATA_URL = \"ftp://sys.giskard.ai/pub/unit_test_resources/fake_real_news_dataset/{}\"\n", |
| 146 | + "DATA_URL = \"https://giskard-library-test-datasets.s3.eu-north-1.amazonaws.com/fake_real_news_dataset-{}\"\n", |
146 | 147 | "DATA_PATH = Path.home() / \".giskard\" / \"fake_real_news_dataset\""
|
147 | 148 | ]
|
148 | 149 | },
|
|
170 | 171 | },
|
171 | 172 | "outputs": [],
|
172 | 173 | "source": [
|
173 |
| - "def fetch_from_ftp(url: str, file: Path) -> None:\n", |
| 174 | + "def fetch_demo_data(url: str, file: Path) -> None:\n", |
174 |
| - "    \"\"\"Helper to fetch data from the FTP server.\"\"\"\n", |
| 175 | + "    \"\"\"Helper to fetch the demo data over HTTPS.\"\"\"\n", |
|
175 | 176 | " if not file.parent.exists():\n",
|
176 | 177 | " file.parent.mkdir(parents=True, exist_ok=True)\n",
|
|
184 | 185 | "\n",
|
185 | 186 | "def fetch_dataset() -> None:\n",
|
186 |
| - "    \"\"\"Gradually fetch all necessary files from the FTP server.\"\"\"\n", |
| 187 | + "    \"\"\"Gradually fetch all necessary files of the demo dataset.\"\"\"\n", |
|
187 |
| - " files_to_fetch = (\"Fake.csv\", \"True.csv\", \"glove_100d.txt\")\n", |
| 188 | + " files_to_fetch = (\"Fake.csv.tar.gz\", \"True.csv.tar.gz\", \"glove_100d.txt.tar.gz\")\n", |
188 | 189 | " for file_name in files_to_fetch:\n",
|
189 |
| - " fetch_from_ftp(DATA_URL.format(file_name), DATA_PATH / file_name)\n", |
| 190 | + " fetch_demo_data(DATA_URL.format(file_name), DATA_PATH / file_name)\n", |
190 | 191 | "\n",
|
191 | 192 | "\n",
|
192 | 193 | "def load_data(**kwargs) -> pd.DataFrame:\n",
|
193 | 194 | " \"\"\"Load data.\"\"\"\n",
|
194 |
| - " real_df = pd.read_csv(DATA_PATH / \"True.csv\", **kwargs)\n", |
195 |
| - " fake_df = pd.read_csv(DATA_PATH / \"Fake.csv\", **kwargs)\n", |
| 195 | + " real_df = pd.read_csv(DATA_PATH / \"True.csv.tar.gz\", **kwargs)\n", |
| 196 | + " fake_df = pd.read_csv(DATA_PATH / \"Fake.csv.tar.gz\", **kwargs)\n", |
196 | 197 | "\n",
|
197 | 198 | " # Create target column.\n",
|
198 | 199 | " real_df[TARGET_COLUMN_NAME] = 0\n",
|
|
380 | 381 | "def get_embeddings_matrix() -> np.ndarray:\n",
|
381 | 382 | " \"\"\"Create matrix, where each row is an embedding of a specific word.\"\"\"\n",
|
382 | 383 | " # Load glove embeddings.\n",
|
383 |
| - " embeddings_dict = dict(parse_line(*line.rstrip().rsplit(' ')) for line in open(DATA_PATH / \"glove_100d.txt\"))\n", |
| 384 | + "    embeddings_dict = dict(parse_line(*line.rstrip().rsplit(' ')) for line in tarfile.open(DATA_PATH / \"glove_100d.txt.tar.gz\", \"r:gz\").extractfile(\"fake_real_news_dataset-glove_100d.txt\").read().decode().splitlines())\n", |
384 | 385 | "\n",
|
385 | 386 | " # Create embeddings matrix with glove word vectors.\n",
|
386 | 387 | " embeddings_matrix = init_embeddings_matrix(embeddings_dict)\n",
|
|