Skip to content

Commit 4061f08

Browse files
keueugene-kulakyevhenii-ldvgirarda
authoredDec 21, 2023
CDK: Add schema normalization to declarative stream (#32786)
Co-authored-by: Eugene Kulak <kulak.eugene@gmail.com> Co-authored-by: Yevhenii Kurochkin <ykurochkin@flyaps.com> Co-authored-by: Alexandre Girard <alexandre@airbyte.io>
1 parent 08c2da2 commit 4061f08

14 files changed

+212
-48
lines changed
 

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_component_schema.yaml

+17-6
Original file line numberDiff line numberDiff line change
@@ -1203,12 +1203,10 @@ definitions:
12031203
http_method:
12041204
title: HTTP Method
12051205
description: The HTTP method used to fetch data from the source (can be GET or POST).
1206-
anyOf:
1207-
- type: string
1208-
- type: string
1209-
enum:
1210-
- GET
1211-
- POST
1206+
type: string
1207+
enum:
1208+
- GET
1209+
- POST
12121210
default: GET
12131211
examples:
12141212
- GET
@@ -1822,9 +1820,22 @@ definitions:
18221820
title: Record Filter
18231821
description: Responsible for filtering records to be emitted by the Source.
18241822
"$ref": "#/definitions/RecordFilter"
1823+
schema_normalization:
1824+
"$ref": "#/definitions/SchemaNormalization"
1825+
default: None
18251826
$parameters:
18261827
type: object
18271828
additionalProperties: true
1829+
SchemaNormalization:
1830+
title: Schema Normalization
1831+
description: Responsible for normalization according to the schema.
1832+
type: string
1833+
enum:
1834+
- None
1835+
- Default
1836+
examples:
1837+
- None
1838+
- Default
18281839
RemoveFields:
18291840
title: Remove Fields
18301841
description: A transformation which removes fields from a record. The fields removed are designated using FieldPointers. During transformation, if a field or any of its parents does not exist in the record, no error is thrown.

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/declarative_stream.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def read_records(
101101
"""
102102
:param: stream_state We knowingly avoid using stream_state as we want cursors to manage their own state.
103103
"""
104-
yield from self.retriever.read_records(stream_slice)
104+
yield from self.retriever.read_records(self.get_json_schema(), stream_slice)
105105

106106
def get_json_schema(self) -> Mapping[str, Any]: # type: ignore
107107
"""

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/extractors/http_selector.py

+2
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ def select_records(
2222
self,
2323
response: requests.Response,
2424
stream_state: StreamState,
25+
records_schema: Mapping[str, Any],
2526
stream_slice: Optional[StreamSlice] = None,
2627
next_page_token: Optional[Mapping[str, Any]] = None,
2728
) -> List[Record]:
2829
"""
2930
Selects records from the response
3031
:param response: The response to select the records from
3132
:param stream_state: The stream state
33+
:param records_schema: json schema of records to return
3234
:param stream_slice: The stream slice
3335
:param next_page_token: The paginator token
3436
:return: List of Records selected from the response

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/extractors/record_selector.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,15 @@
99
from airbyte_cdk.sources.declarative.extractors.http_selector import HttpSelector
1010
from airbyte_cdk.sources.declarative.extractors.record_extractor import RecordExtractor
1111
from airbyte_cdk.sources.declarative.extractors.record_filter import RecordFilter
12+
from airbyte_cdk.sources.declarative.models import SchemaNormalization
1213
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
1314
from airbyte_cdk.sources.declarative.types import Config, Record, StreamSlice, StreamState
15+
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
16+
17+
SCHEMA_TRANSFORMER_TYPE_MAPPING = {
18+
SchemaNormalization.None_: TransformConfig.NoTransform,
19+
SchemaNormalization.Default: TransformConfig.DefaultSchemaNormalization,
20+
}
1421

1522

1623
@dataclass
@@ -21,13 +28,15 @@ class RecordSelector(HttpSelector):
2128
2229
Attributes:
2330
extractor (RecordExtractor): The record extractor responsible for extracting records from a response
31+
schema_normalization (TypeTransformer): The record normalizer responsible for casting record values to stream schema types
2432
record_filter (RecordFilter): The record filter responsible for filtering extracted records
2533
transformations (List[RecordTransformation]): The transformations to be done on the records
2634
"""
2735

2836
extractor: RecordExtractor
2937
config: Config
3038
parameters: InitVar[Mapping[str, Any]]
39+
schema_normalization: TypeTransformer
3140
record_filter: Optional[RecordFilter] = None
3241
transformations: List[RecordTransformation] = field(default_factory=lambda: [])
3342

@@ -38,14 +47,31 @@ def select_records(
3847
self,
3948
response: requests.Response,
4049
stream_state: StreamState,
50+
records_schema: Mapping[str, Any],
4151
stream_slice: Optional[StreamSlice] = None,
4252
next_page_token: Optional[Mapping[str, Any]] = None,
4353
) -> List[Record]:
54+
"""
55+
Selects records from the response
56+
:param response: The response to select the records from
57+
:param stream_state: The stream state
58+
:param records_schema: json schema of records to return
59+
:param stream_slice: The stream slice
60+
:param next_page_token: The paginator token
61+
:return: List of Records selected from the response
62+
"""
4463
all_data = self.extractor.extract_records(response)
4564
filtered_data = self._filter(all_data, stream_state, stream_slice, next_page_token)
4665
self._transform(filtered_data, stream_state, stream_slice)
66+
self._normalize_by_schema(filtered_data, schema=records_schema)
4767
return [Record(data, stream_slice) for data in filtered_data]
4868

69+
def _normalize_by_schema(self, records: List[Mapping[str, Any]], schema: Optional[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
70+
if schema:
71+
# record has type Mapping[str, Any], but dict[str, Any] expected
72+
return [self.schema_normalization.transform(record, schema) for record in records] # type: ignore
73+
return records
74+
4975
def _filter(
5076
self,
5177
records: List[Mapping[str, Any]],
@@ -67,4 +93,5 @@ def _transform(
6793
) -> None:
6894
for record in records:
6995
for transformation in self.transformations:
70-
transformation.transform(record, config=self.config, stream_state=stream_state, stream_slice=stream_slice)
96+
# record has type Mapping[str, Any], but Record expected
97+
transformation.transform(record, config=self.config, stream_state=stream_state, stream_slice=stream_slice) # type: ignore

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/models/declarative_component_schema.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,7 @@ class SessionTokenRequestBearerAuthenticator(BaseModel):
340340
type: Literal['Bearer']
341341

342342

343-
class HttpMethodEnum(Enum):
343+
class HttpMethod(Enum):
344344
GET = 'GET'
345345
POST = 'POST'
346346

@@ -572,6 +572,11 @@ class RecordFilter(BaseModel):
572572
parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters')
573573

574574

575+
class SchemaNormalization(Enum):
576+
None_ = 'None'
577+
Default = 'Default'
578+
579+
575580
class RemoveFields(BaseModel):
576581
type: Literal['RemoveFields']
577582
field_pointers: List[List[str]] = Field(
@@ -1019,6 +1024,7 @@ class RecordSelector(BaseModel):
10191024
description='Responsible for filtering records to be emitted by the Source.',
10201025
title='Record Filter',
10211026
)
1027+
schema_normalization: Optional[SchemaNormalization] = SchemaNormalization.None_
10221028
parameters: Optional[Dict[str, Any]] = Field(None, alias='$parameters')
10231029

10241030

@@ -1232,8 +1238,8 @@ class HttpRequester(BaseModel):
12321238
description='Error handler component that defines how to handle errors.',
12331239
title='Error Handler',
12341240
)
1235-
http_method: Optional[Union[str, HttpMethodEnum]] = Field(
1236-
'GET',
1241+
http_method: Optional[HttpMethod] = Field(
1242+
HttpMethod.GET,
12371243
description='The HTTP method used to fetch data from the source (can be GET or POST).',
12381244
examples=['GET', 'POST'],
12391245
title='HTTP Method',

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from airbyte_cdk.sources.declarative.declarative_stream import DeclarativeStream
2727
from airbyte_cdk.sources.declarative.decoders import JsonDecoder
2828
from airbyte_cdk.sources.declarative.extractors import DpathExtractor, RecordFilter, RecordSelector
29+
from airbyte_cdk.sources.declarative.extractors.record_selector import SCHEMA_TRANSFORMER_TYPE_MAPPING
2930
from airbyte_cdk.sources.declarative.incremental import Cursor, CursorFactory, DatetimeBasedCursor, PerPartitionCursor
3031
from airbyte_cdk.sources.declarative.interpolation import InterpolatedString
3132
from airbyte_cdk.sources.declarative.interpolation.interpolated_mapping import InterpolatedMapping
@@ -107,6 +108,7 @@
107108
from airbyte_cdk.sources.declarative.requesters.request_option import RequestOptionType
108109
from airbyte_cdk.sources.declarative.requesters.request_options import InterpolatedRequestOptionsProvider
109110
from airbyte_cdk.sources.declarative.requesters.request_path import RequestPath
111+
from airbyte_cdk.sources.declarative.requesters.requester import HttpMethod
110112
from airbyte_cdk.sources.declarative.retrievers import SimpleRetriever, SimpleRetrieverTestReadDecorator
111113
from airbyte_cdk.sources.declarative.schema import DefaultSchemaLoader, InlineSchemaLoader, JsonFileSchemaLoader
112114
from airbyte_cdk.sources.declarative.spec import Spec
@@ -115,6 +117,7 @@
115117
from airbyte_cdk.sources.declarative.transformations.add_fields import AddedFieldDefinition
116118
from airbyte_cdk.sources.declarative.types import Config
117119
from airbyte_cdk.sources.message import InMemoryMessageRepository, LogAppenderMessageRepositoryDecorator, MessageRepository
120+
from airbyte_cdk.sources.utils.transform import TypeTransformer
118121
from isodate import parse_duration
119122
from pydantic import BaseModel
120123

@@ -710,9 +713,8 @@ def create_http_requester(self, model: HttpRequesterModel, config: Config, *, na
710713
parameters=model.parameters or {},
711714
)
712715

713-
model_http_method = (
714-
model.http_method if isinstance(model.http_method, str) else model.http_method.value if model.http_method is not None else "GET"
715-
)
716+
assert model.use_cache is not None # for mypy
717+
assert model.http_method is not None # for mypy
716718

717719
assert model.use_cache is not None # for mypy
718720

@@ -722,7 +724,7 @@ def create_http_requester(self, model: HttpRequesterModel, config: Config, *, na
722724
path=model.path,
723725
authenticator=authenticator,
724726
error_handler=error_handler,
725-
http_method=model_http_method,
727+
http_method=HttpMethod[model.http_method.value],
726728
request_options_provider=request_options_provider,
727729
config=config,
728730
disable_retries=self._disable_retries,
@@ -884,16 +886,24 @@ def create_request_option(model: RequestOptionModel, config: Config, **kwargs: A
884886
return RequestOption(field_name=model.field_name, inject_into=inject_into, parameters={})
885887

886888
def create_record_selector(
887-
self, model: RecordSelectorModel, config: Config, *, transformations: List[RecordTransformation], **kwargs: Any
889+
self,
890+
model: RecordSelectorModel,
891+
config: Config,
892+
*,
893+
transformations: List[RecordTransformation],
894+
**kwargs: Any,
888895
) -> RecordSelector:
896+
assert model.schema_normalization is not None # for mypy
889897
extractor = self._create_component_from_model(model=model.extractor, config=config)
890898
record_filter = self._create_component_from_model(model.record_filter, config=config) if model.record_filter else None
899+
schema_normalization = TypeTransformer(SCHEMA_TRANSFORMER_TYPE_MAPPING[model.schema_normalization])
891900

892901
return RecordSelector(
893902
extractor=extractor,
894903
config=config,
895904
record_filter=record_filter,
896905
transformations=transformations,
906+
schema_normalization=schema_normalization,
897907
parameters=model.parameters or {},
898908
)
899909

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/requesters/http_requester.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class HttpRequester(Requester):
5959
config: Config
6060
parameters: InitVar[Mapping[str, Any]]
6161
authenticator: Optional[DeclarativeAuthenticator] = None
62-
http_method: Union[str, HttpMethod] = HttpMethod.GET
62+
http_method: HttpMethod = HttpMethod.GET
6363
request_options_provider: Optional[InterpolatedRequestOptionsProvider] = None
6464
error_handler: Optional[ErrorHandler] = None
6565
disable_retries: bool = False
@@ -80,7 +80,6 @@ def __post_init__(self, parameters: Mapping[str, Any]) -> None:
8080
else:
8181
self._request_options_provider = self.request_options_provider
8282
self._authenticator = self.authenticator or NoAuth(parameters=parameters)
83-
self._http_method = HttpMethod[self.http_method] if isinstance(self.http_method, str) else self.http_method
8483
self.error_handler = self.error_handler
8584
self._parameters = parameters
8685
self.decoder = JsonDecoder(parameters={})
@@ -139,7 +138,7 @@ def get_path(
139138
return path.lstrip("/")
140139

141140
def get_method(self) -> HttpMethod:
142-
return self._http_method
141+
return self.http_method
143142

144143
def interpret_response_status(self, response: requests.Response) -> ResponseStatus:
145144
if self.error_handler is None:
@@ -420,7 +419,7 @@ def _create_prepared_request(
420419
data: Any = None,
421420
) -> requests.PreparedRequest:
422421
url = urljoin(self.get_url_base(), path)
423-
http_method = str(self._http_method.value)
422+
http_method = str(self.http_method.value)
424423
query_params = self.deduplicate_query_params(url, params)
425424
args = {"method": http_method, "url": url, "headers": headers, "params": query_params}
426425
if http_method.upper() in BODY_REQUEST_METHODS:

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/retriever.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from abc import abstractmethod
66
from dataclasses import dataclass
7-
from typing import Iterable, Optional
7+
from typing import Any, Iterable, Mapping, Optional
88

99
from airbyte_cdk.sources.declarative.types import StreamSlice, StreamState
1010
from airbyte_cdk.sources.streams.core import StreamData
@@ -19,15 +19,14 @@ class Retriever:
1919
@abstractmethod
2020
def read_records(
2121
self,
22+
records_schema: Mapping[str, Any],
2223
stream_slice: Optional[StreamSlice] = None,
2324
) -> Iterable[StreamData]:
2425
"""
2526
Fetch a stream's records from an HTTP API source
2627
27-
:param sync_mode: Unused but currently necessary for integrating with HttpStream
28-
:param cursor_field: Unused but currently necessary for integrating with HttpStream
28+
:param records_schema: json schema to describe record
2929
:param stream_slice: The stream slice to read data for
30-
:param stream_state: The initial stream state
3130
:return: The records read from the API source
3231
"""
3332

‎airbyte-cdk/python/airbyte_cdk/sources/declarative/retrievers/simple_retriever.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44

55
from dataclasses import InitVar, dataclass, field
6+
from functools import partial
67
from itertools import islice
78
from typing import Any, Callable, Iterable, List, Mapping, Optional, Set, Tuple, Union
89

@@ -215,6 +216,7 @@ def _parse_response(
215216
self,
216217
response: Optional[requests.Response],
217218
stream_state: StreamState,
219+
records_schema: Mapping[str, Any],
218220
stream_slice: Optional[StreamSlice] = None,
219221
next_page_token: Optional[Mapping[str, Any]] = None,
220222
) -> Iterable[Record]:
@@ -225,7 +227,11 @@ def _parse_response(
225227

226228
self._last_response = response
227229
records = self.record_selector.select_records(
228-
response=response, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
230+
response=response,
231+
stream_state=stream_state,
232+
records_schema=records_schema,
233+
stream_slice=stream_slice,
234+
next_page_token=next_page_token,
229235
)
230236
self._records_from_last_response = records
231237
return records
@@ -271,16 +277,15 @@ def _fetch_next_page(
271277
# This logic is similar to _read_pages in the HttpStream class. When making changes here, consider making changes there as well.
272278
def _read_pages(
273279
self,
274-
records_generator_fn: Callable[[Optional[requests.Response], Mapping[str, Any], Mapping[str, Any]], Iterable[StreamData]],
280+
records_generator_fn: Callable[[Optional[requests.Response]], Iterable[StreamData]],
275281
stream_state: Mapping[str, Any],
276282
stream_slice: Mapping[str, Any],
277283
) -> Iterable[StreamData]:
278-
stream_state = stream_state or {}
279284
pagination_complete = False
280285
next_page_token = None
281286
while not pagination_complete:
282287
response = self._fetch_next_page(stream_state, stream_slice, next_page_token)
283-
yield from records_generator_fn(response, stream_state, stream_slice)
288+
yield from records_generator_fn(response)
284289

285290
if not response:
286291
pagination_complete = True
@@ -294,14 +299,28 @@ def _read_pages(
294299

295300
def read_records(
296301
self,
302+
records_schema: Mapping[str, Any],
297303
stream_slice: Optional[StreamSlice] = None,
298304
) -> Iterable[StreamData]:
305+
"""
306+
Fetch a stream's records from an HTTP API source
307+
308+
:param records_schema: json schema to describe record
309+
:param stream_slice: The stream slice to read data for
310+
:return: The records read from the API source
311+
"""
299312
stream_slice = stream_slice or {} # None-check
300313
# Fixing paginator types has a long tail of dependencies
301314
self._paginator.reset()
302315

303316
most_recent_record_from_slice = None
304-
for stream_data in self._read_pages(self._parse_records, self.state, stream_slice):
317+
record_generator = partial(
318+
self._parse_records,
319+
stream_state=self.state or {},
320+
stream_slice=stream_slice,
321+
records_schema=records_schema,
322+
)
323+
for stream_data in self._read_pages(record_generator, self.state, stream_slice):
305324
most_recent_record_from_slice = self._get_most_recent_record(most_recent_record_from_slice, stream_data, stream_slice)
306325
yield stream_data
307326

@@ -361,9 +380,15 @@ def _parse_records(
361380
self,
362381
response: Optional[requests.Response],
363382
stream_state: Mapping[str, Any],
383+
records_schema: Mapping[str, Any],
364384
stream_slice: Optional[Mapping[str, Any]],
365385
) -> Iterable[StreamData]:
366-
yield from self._parse_response(response, stream_slice=stream_slice, stream_state=stream_state)
386+
yield from self._parse_response(
387+
response,
388+
stream_slice=stream_slice,
389+
stream_state=stream_state,
390+
records_schema=records_schema,
391+
)
367392

368393
def must_deduplicate_query_params(self) -> bool:
369394
return True

‎airbyte-cdk/python/bin/generate-component-manifest-files.sh

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ function main() {
1919
--input "/airbyte/$YAML_DIR/$filename_wo_ext.yaml" \
2020
--output "/airbyte/$OUTPUT_DIR/$filename_wo_ext.py" \
2121
--disable-timestamp \
22-
--enum-field-as-literal one
22+
--enum-field-as-literal one \
23+
--set-default-enum-member
2324

2425
# There is a limitation of Pydantic where a model's private fields starting with an underscore are inaccessible.
2526
# The Pydantic model generator replaces special characters like $ with the underscore which results in all

‎airbyte-cdk/python/unit_tests/sources/declarative/extractors/test_record_selector.py

+80-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from airbyte_cdk.sources.declarative.extractors.record_selector import RecordSelector
1414
from airbyte_cdk.sources.declarative.transformations import RecordTransformation
1515
from airbyte_cdk.sources.declarative.types import Record
16+
from airbyte_cdk.sources.utils.transform import TransformConfig, TypeTransformer
1617

1718

1819
@pytest.mark.parametrize(
@@ -68,6 +69,7 @@ def test_record_filter(test_name, field_path, filter_template, body, expected_da
6869
stream_state = {"created_at": "06-06-21"}
6970
stream_slice = {"last_seen": "06-10-21"}
7071
next_page_token = {"last_seen_id": 14}
72+
schema = create_schema()
7173
first_transformation = Mock(spec=RecordTransformation)
7274
second_transformation = Mock(spec=RecordTransformation)
7375
transformations = [first_transformation, second_transformation]
@@ -80,13 +82,19 @@ def test_record_filter(test_name, field_path, filter_template, body, expected_da
8082
else:
8183
record_filter = RecordFilter(config=config, condition=filter_template, parameters=parameters)
8284
record_selector = RecordSelector(
83-
extractor=extractor, record_filter=record_filter, transformations=transformations, config=config, parameters=parameters
85+
extractor=extractor,
86+
record_filter=record_filter,
87+
transformations=transformations,
88+
config=config,
89+
parameters=parameters,
90+
schema_normalization=TypeTransformer(TransformConfig.NoTransform),
8491
)
8592

8693
actual_records = record_selector.select_records(
87-
response=response, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
94+
response=response, records_schema=schema, stream_state=stream_state, stream_slice=stream_slice, next_page_token=next_page_token
8895
)
8996
assert actual_records == [Record(data, stream_slice) for data in expected_data]
97+
9098
calls = []
9199
for record in expected_data:
92100
calls.append(call(record, config=config, stream_state=stream_state, stream_slice=stream_slice))
@@ -95,7 +103,77 @@ def test_record_filter(test_name, field_path, filter_template, body, expected_da
95103
transformation.transform.assert_has_calls(calls)
96104

97105

106+
@pytest.mark.parametrize(
107+
"test_name, schema, schema_transformation, body, expected_data",
108+
[
109+
(
110+
"test_with_empty_schema",
111+
{},
112+
TransformConfig.NoTransform,
113+
{"data": [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}]},
114+
[{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}],
115+
),
116+
(
117+
"test_with_schema_none_normalizer",
118+
{},
119+
TransformConfig.NoTransform,
120+
{"data": [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}]},
121+
[{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}],
122+
),
123+
(
124+
"test_with_schema_and_default_normalizer",
125+
{},
126+
TransformConfig.DefaultSchemaNormalization,
127+
{"data": [{"id": 1, "created_at": "06-06-21", "field_int": "100", "field_float": "123.3"}]},
128+
[{"id": "1", "created_at": "06-06-21", "field_int": 100, "field_float": 123.3}],
129+
),
130+
],
131+
)
132+
def test_schema_normalization(test_name, schema, schema_transformation, body, expected_data):
133+
config = {"response_override": "stop_if_you_see_me"}
134+
parameters = {"parameters_field": "data", "created_at": "06-07-21"}
135+
stream_state = {"created_at": "06-06-21"}
136+
stream_slice = {"last_seen": "06-10-21"}
137+
next_page_token = {"last_seen_id": 14}
138+
139+
response = create_response(body)
140+
schema = create_schema()
141+
decoder = JsonDecoder(parameters={})
142+
extractor = DpathExtractor(field_path=["data"], decoder=decoder, config=config, parameters=parameters)
143+
record_selector = RecordSelector(
144+
extractor=extractor,
145+
record_filter=None,
146+
transformations=[],
147+
config=config,
148+
parameters=parameters,
149+
schema_normalization=TypeTransformer(schema_transformation),
150+
)
151+
152+
actual_records = record_selector.select_records(
153+
response=response,
154+
stream_state=stream_state,
155+
stream_slice=stream_slice,
156+
next_page_token=next_page_token,
157+
records_schema=schema,
158+
)
159+
160+
assert actual_records == [Record(data, stream_slice) for data in expected_data]
161+
162+
98163
def create_response(body):
99164
response = requests.Response()
100165
response._content = json.dumps(body).encode("utf-8")
101166
return response
167+
168+
169+
def create_schema():
170+
return {
171+
"$schema": "http://json-schema.org/draft-07/schema#",
172+
"type": "object",
173+
"properties": {
174+
"id": {"type": "string"},
175+
"created_at": {"type": "string"},
176+
"field_int": {"type": "integer"},
177+
"field_float": {"type": "number"},
178+
},
179+
}

‎airbyte-cdk/python/unit_tests/sources/declarative/parsers/test_model_to_component_factory.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def test_full_config_stream():
245245
assert stream.retriever.paginator.pagination_strategy.page_size == 10
246246

247247
assert isinstance(stream.retriever.requester, HttpRequester)
248-
assert stream.retriever.requester._http_method == HttpMethod.GET
248+
assert stream.retriever.requester.http_method == HttpMethod.GET
249249
assert stream.retriever.requester.name == stream.name
250250
assert stream.retriever.requester._path.string == "{{ next_page_token['next_page_url'] }}"
251251
assert stream.retriever.requester._path.default == "{{ next_page_token['next_page_url'] }}"
@@ -829,7 +829,7 @@ def test_create_requester(test_name, error_handler, expected_backoff_strategy_ty
829829
)
830830

831831
assert isinstance(selector, HttpRequester)
832-
assert selector._http_method == HttpMethod.GET
832+
assert selector.http_method == HttpMethod.GET
833833
assert selector.name == "name"
834834
assert selector._path.string == "/v3/marketing/lists"
835835
assert selector._url_base.string == "https://api.sendgrid.com"
@@ -1075,7 +1075,7 @@ def test_config_with_defaults():
10751075
assert stream.schema_loader.file_path.default == "./source_sendgrid/schemas/{{ parameters.name }}.yaml"
10761076

10771077
assert isinstance(stream.retriever.requester, HttpRequester)
1078-
assert stream.retriever.requester._http_method == HttpMethod.GET
1078+
assert stream.retriever.requester.http_method == HttpMethod.GET
10791079

10801080
assert isinstance(stream.retriever.requester.authenticator, BearerAuthenticator)
10811081
assert stream.retriever.requester.authenticator.token_provider.get_token() == "verysecrettoken"

‎airbyte-cdk/python/unit_tests/sources/declarative/requesters/test_http_requester.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def factory(
6161

6262

6363
def test_http_requester():
64-
http_method = "GET"
64+
http_method = HttpMethod.GET
6565

6666
request_options_provider = MagicMock()
6767
request_params = {"param": "value"}
@@ -106,7 +106,7 @@ def test_http_requester():
106106
assert requester.get_url_base() == "https://airbyte.io/"
107107
assert requester.get_path(stream_state={}, stream_slice=stream_slice, next_page_token={}) == "v1/1234"
108108
assert requester.get_authenticator() == authenticator
109-
assert requester.get_method() == HttpMethod.GET
109+
assert requester.get_method() == http_method
110110
assert requester.get_request_params(stream_state={}, stream_slice=None, next_page_token=None) == request_params
111111
assert requester.get_request_body_data(stream_state={}, stream_slice=None, next_page_token=None) == request_body_data
112112
assert requester.get_request_body_json(stream_state={}, stream_slice=None, next_page_token=None) == request_body_json

‎airbyte-cdk/python/unit_tests/sources/declarative/retrievers/test_simple_retriever.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#
22
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
33
#
4-
54
from unittest.mock import MagicMock, Mock, patch
65

76
import pytest
@@ -94,7 +93,7 @@ def test_simple_retriever_full(mock_http_stream):
9493

9594
assert retriever._last_response is None
9695
assert retriever._records_from_last_response == []
97-
assert retriever._parse_response(response, stream_state={}) == records
96+
assert retriever._parse_response(response, stream_state={}, records_schema={}) == records
9897
assert retriever._last_response == response
9998
assert retriever._records_from_last_response == records
10099

@@ -170,7 +169,7 @@ def test_simple_retriever_with_request_response_log_last_records(mock_http_strea
170169

171170
assert retriever._last_response is None
172171
assert retriever._records_from_last_response == []
173-
assert retriever._parse_response(response, stream_state={}) == request_response_logs
172+
assert retriever._parse_response(response, stream_state={}, records_schema={}) == request_response_logs
174173
assert retriever._last_response == response
175174
assert retriever._records_from_last_response == request_response_logs
176175

@@ -396,13 +395,16 @@ def test_when_read_records_then_cursor_close_slice_with_greater_record(test_name
396395
)
397396
stream_slice = {"repository": "airbyte"}
398397

398+
def retriever_read_pages(_, __, ___):
399+
return retriever._parse_records(response=MagicMock(), stream_state={}, stream_slice=stream_slice, records_schema={})
400+
399401
with patch.object(
400402
SimpleRetriever,
401403
"_read_pages",
402404
return_value=iter([first_record, second_record]),
403-
side_effect=lambda _, __, ___: retriever._parse_records(response=MagicMock(), stream_state=None, stream_slice=stream_slice),
405+
side_effect=retriever_read_pages,
404406
):
405-
list(retriever.read_records(stream_slice=stream_slice))
407+
list(retriever.read_records(stream_slice=stream_slice, records_schema={}))
406408
cursor.close_slice.assert_called_once_with(stream_slice, first_record if first_greater_than_second else second_record)
407409

408410

@@ -425,13 +427,16 @@ def test_given_stream_data_is_not_record_when_read_records_then_update_slice_wit
425427
)
426428
stream_slice = {"repository": "airbyte"}
427429

430+
def retriever_read_pages(_, __, ___):
431+
return retriever._parse_records(response=MagicMock(), stream_state={}, stream_slice=stream_slice, records_schema={})
432+
428433
with patch.object(
429434
SimpleRetriever,
430435
"_read_pages",
431436
return_value=iter(stream_data),
432-
side_effect=lambda _, __, ___: retriever._parse_records(response=MagicMock(), stream_state=None, stream_slice=stream_slice),
437+
side_effect=retriever_read_pages,
433438
):
434-
list(retriever.read_records(stream_slice=stream_slice))
439+
list(retriever.read_records(stream_slice=stream_slice, records_schema={}))
435440
cursor.close_slice.assert_called_once_with(stream_slice, None)
436441

437442

@@ -440,7 +445,7 @@ def _generate_slices(number_of_slices):
440445

441446

442447
@patch.object(SimpleRetriever, "_read_pages", return_value=iter([]))
443-
def test_given_state_selector_when_read_records_use_stream_state(http_stream_read_pages):
448+
def test_given_state_selector_when_read_records_use_stream_state(http_stream_read_pages, mocker):
444449
requester = MagicMock()
445450
paginator = MagicMock()
446451
record_selector = MagicMock()
@@ -459,9 +464,10 @@ def test_given_state_selector_when_read_records_use_stream_state(http_stream_rea
459464
parameters={},
460465
config={},
461466
)
462-
list(retriever.read_records(stream_slice=A_STREAM_SLICE))
463467

464-
http_stream_read_pages.assert_called_once_with(retriever._parse_records, A_STREAM_STATE, A_STREAM_SLICE)
468+
list(retriever.read_records(stream_slice=A_STREAM_SLICE, records_schema={}))
469+
470+
http_stream_read_pages.assert_called_once_with(mocker.ANY, A_STREAM_STATE, A_STREAM_SLICE)
465471

466472

467473
def test_emit_log_request_response_messages(mocker):

0 commit comments

Comments
 (0)
Please sign in to comment.