Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion drift/instrumentation/psycopg/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
TuskDriftMode,
)
from ..base import InstrumentationBase
from ..utils.psycopg_utils import deserialize_db_value
from ..utils.psycopg_utils import deserialize_db_value, restore_row_integer_types
from ..utils.serialization import serialize_value
from .mocks import MockConnection, MockCopy
from .wrappers import TracedCopyWrapper
Expand Down Expand Up @@ -1663,6 +1663,7 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any], is_asy
mock_rows = actual_data.get("rows", [])
# Deserialize datetime strings back to datetime objects for consistent Flask serialization
mock_rows = [deserialize_db_value(row) for row in mock_rows]
mock_rows = [restore_row_integer_types(row, description_data) for row in mock_rows]
cursor._mock_rows = mock_rows # pyright: ignore[reportAttributeAccessIssue]
cursor._mock_index = 0 # pyright: ignore[reportAttributeAccessIssue]

Expand Down Expand Up @@ -1731,6 +1732,7 @@ def _mock_executemany_returning_with_data(self, cursor: Any, mock_data: dict[str
# Deserialize rows
mock_rows = result_set.get("rows", [])
mock_rows = [deserialize_db_value(row) for row in mock_rows]
mock_rows = [restore_row_integer_types(row, description_data) for row in mock_rows]

cursor._mock_result_sets.append( # pyright: ignore[reportAttributeAccessIssue]
{
Expand Down
3 changes: 2 additions & 1 deletion drift/instrumentation/psycopg2/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
TuskDriftMode,
)
from ..base import InstrumentationBase
from ..utils.psycopg_utils import deserialize_db_value
from ..utils.psycopg_utils import deserialize_db_value, restore_row_integer_types
from ..utils.serialization import serialize_value

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -890,6 +890,7 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
mock_rows = actual_data.get("rows", [])
# Deserialize datetime strings back to datetime objects for consistent Flask/Django serialization
mock_rows = [deserialize_db_value(row) for row in mock_rows]
mock_rows = [restore_row_integer_types(row, description_data) for row in mock_rows]

# Check if this is a dict-cursor (like RealDictCursor)
# First check if cursor has _is_dict_cursor attribute (set by InstrumentedConnection.cursor())
Expand Down
68 changes: 68 additions & 0 deletions drift/instrumentation/utils/psycopg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@
from decimal import Decimal
from typing import Any

# PostgreSQL type OIDs whose values are whole numbers.
# A column's ``type_code`` in the cursor description is compared against this
# set by ``restore_row_integer_types`` to decide whether a whole-number float
# in that column should be converted back to ``int``.
# See: https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_type.dat
POSTGRES_INTEGER_TYPE_CODES = {
    20,  # BIGINT (int8)
    21,  # SMALLINT (int2)
    23,  # INTEGER (int4)
    26,  # OID (object identifier)
    28,  # XID (transaction identifier)
}

# Try to import psycopg Range type for deserialization support
try:
from psycopg.types.range import Range as PsycopgRange # type: ignore[import-untyped]
Expand Down Expand Up @@ -86,3 +97,60 @@ def deserialize_db_value(val: Any) -> Any:
elif isinstance(val, list):
return [deserialize_db_value(v) for v in val]
return val


def restore_row_integer_types(
row: list[Any] | dict[str, Any], description: list[dict[str, Any]] | None
) -> list[Any] | dict[str, Any]:
"""Restore integer types for database row values using column metadata.

During the record/replay cycle, integer values are lost due to JSON serialization:
- Recording: PostgreSQL INTEGER column → psycopg2 returns int(0) -> JSON stores 0
- Replay: CLI parses JSON -> Go float64(0) -> protobuf double -> Python float(0.0)

This function uses the column type_code from the cursor description to identify
which columns should contain integers and converts whole-number floats back to int.

Args:
row: A row of values from the mocked database query. Can be a list (standard cursor)
or dict (dict cursor like RealDictCursor).
description: Column metadata from the cursor description, containing 'type_code' for each column.
Format: [{"name": "col1", "type_code": 23}, ...]

Returns:
The row with integer types restored for INTEGER columns.
"""
if not description or not row:
return row

# Handle dict rows (from dict cursors like RealDictCursor)
if isinstance(row, dict):
# Build a mapping of column names to type codes
type_code_by_name = {}
for col in description:
if isinstance(col, dict):
col_name = col.get("name")
if col_name:
type_code_by_name[col_name] = col.get("type_code")

result = {}
for key, value in row.items():
type_code = type_code_by_name.get(key)
if type_code in POSTGRES_INTEGER_TYPE_CODES and isinstance(value, float) and value.is_integer():
result[key] = int(value)
else:
result[key] = value
return result

# Handle list/tuple rows (standard cursors)
result = []
for i, value in enumerate(row):
if i < len(description):
type_code = description[i].get("type_code") if isinstance(description[i], dict) else None
if type_code in POSTGRES_INTEGER_TYPE_CODES and isinstance(value, float) and value.is_integer():
result.append(int(value))
else:
result.append(value)
else:
result.append(value)
return result