Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion airbyte_cdk/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,10 @@ def airbyte_message_to_string(airbyte_message: AirbyteMessage) -> str:
f"There was an error during the serialization of an AirbyteMessage: `{exception}`. This might impact the sync performances."
)
_HAS_LOGGED_FOR_SERIALIZATION_ERROR = True
return json.dumps(serialized_message)
try:
return json.dumps(serialized_message)
except Exception:
return json.dumps(serialized_message, default=str)

@classmethod
def extract_state(cls, args: List[str]) -> Optional[Any]:
Expand Down
22 changes: 22 additions & 0 deletions unit_tests/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,25 @@ def test_given_serialization_error_using_orjson_then_fallback_on_json(
# There will be multiple messages here because the fixture `entrypoint` sets a control message. We only care about records here
record_messages = list(filter(lambda message: "RECORD" in message, messages))
assert len(record_messages) == 2


def test_given_non_json_serializable_type_then_fallback_with_default_str(
    entrypoint: AirbyteEntrypoint, mocker, spec_mock, config_mock
):
    """Check that a record holding a ``complex`` value is still emitted.

    The record's data contains ``complex(1, 2)``. The test expects the
    entrypoint to emit exactly one RECORD message whose text contains the
    ``str()`` form of that value, ``(1+2j)``.
    """
    # A record whose payload carries a complex number.
    complex_record = AirbyteMessage(
        record=AirbyteRecordMessage(
            stream="stream",
            data={"value": complex(1, 2)},
            emitted_at=1,
        ),
        type=Type.RECORD,
    )

    # Stub the source so that `read` yields only the record built above.
    mocker.patch.object(MockSource, "read_state", return_value={})
    mocker.patch.object(MockSource, "read_catalog", return_value={})
    mocker.patch.object(MockSource, "read", return_value=[complex_record])

    read_args = Namespace(
        command="read", config="config_path", state="statepath", catalog="catalogpath"
    )
    emitted = list(entrypoint.run(read_args))

    # Other message kinds may be emitted as well; only records matter here.
    records_only = [output for output in emitted if "RECORD" in output]
    assert len(records_only) == 1
    assert "(1+2j)" in records_only[0]
Loading