Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,13 @@ def get_table_comment(
@calculate_execution_time()
def _get_all_relation_info(self, connection, **kw): # pylint: disable=unused-argument
"""
Get all relation info for a schema.

Uses a custom single-schema cache instead of @reflection.cache
to prevent unbounded memory growth across schemas (issue #20649).
Only the most recently requested schema's data is retained.
The ``table_name`` kwarg is not used for filtering since the
cache is keyed by schema only.
"""
# pylint: disable=consider-using-f-string
schema = kw.get("schema", None)
Expand All @@ -442,15 +447,10 @@ def _get_all_relation_info(self, connection, **kw): # pylint: disable=unused-ar

schema_clause = "AND schema = '{schema}'".format(schema=schema) if schema else ""

table_name = kw.get("table_name", None)
table_clause = (
"AND relname = '{table}'".format(table=table_name) if table_name else ""
)

result = connection.execute(
sa.text(
REDSHIFT_GET_ALL_RELATIONS.format(
schema_clause=schema_clause, table_clause=table_clause, limit_clause=""
schema_clause=schema_clause, table_clause="", limit_clause=""
)
)
)
Expand Down
79 changes: 76 additions & 3 deletions ingestion/tests/unit/topology/database/test_redshift_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
import unittest
from unittest.mock import MagicMock, Mock

from metadata.ingestion.source.database.redshift.utils import get_view_definition
from metadata.ingestion.source.database.redshift.utils import (
_get_all_relation_info,
get_view_definition,
)


class TestRedshiftUtils(unittest.TestCase):
"""Test Redshift Utils"""
class TestGetViewDefinition(unittest.TestCase):
"""Test get_view_definition formatting and prefix handling"""

def setUp(self):
"""Set up test fixtures"""
Expand Down Expand Up @@ -229,5 +232,75 @@ def test_external_view_definition_removes_schema_binding(self):
)


class TestGetAllRelationInfoCache(unittest.TestCase):
    """Exercise the single-schema cache behaviour of _get_all_relation_info."""

    @staticmethod
    def _make_relation(relname, schema):
        """Return a stand-in relation row exposing ``relname`` and ``schema``."""
        relation = Mock()
        relation.relname = relname
        relation.schema = schema
        return relation

    @staticmethod
    def _make_result(relations):
        """Return a fake query result whose iteration yields ``relations``.

        The iterator is created once here, so the result can only be
        consumed a single time.
        """
        query_result = MagicMock()
        query_result.__iter__ = Mock(return_value=iter(relations))
        return query_result

    def setUp(self):
        # An empty spec means reading any attribute that was never set raises
        # AttributeError; presumably this makes the stand-in dialect start
        # without cached state -- confirm against the implementation.
        self.mock_self = Mock(spec=[])
        self.mock_connection = Mock()

    def _lookup(self, **kwargs):
        """Call the function under test with the shared mock dialect/connection."""
        return _get_all_relation_info(self.mock_self, self.mock_connection, **kwargs)

    def test_cache_returns_all_relations_regardless_of_table_name(self):
        """A ``table_name`` kwarg must not narrow what the cache returns."""
        expected_names = ("view_a", "view_b", "table_c")
        rows = [self._make_relation(name, "my_schema") for name in expected_names]
        self.mock_connection.execute.return_value = self._make_result(rows)

        first = self._lookup(schema="my_schema", table_name="view_a")

        self.assertEqual(len(first), 3)
        self.assertEqual({key.name for key in first}, {"view_a", "view_b", "table_c"})

        # Asking for another table in the same schema must hand back the very
        # same dict object without issuing a second query.
        second = self._lookup(schema="my_schema", table_name="view_b")

        self.assertIs(first, second)
        self.mock_connection.execute.assert_called_once()

    def test_cache_invalidates_on_schema_change(self):
        """Requesting a different schema must replace the cached data."""
        self.mock_connection.execute.side_effect = [
            self._make_result([self._make_relation(table, schema)])
            for table, schema in (("t1", "schema_1"), ("t2", "schema_2"))
        ]

        from_first_schema = self._lookup(schema="schema_1")
        self.assertEqual({key.name for key in from_first_schema}, {"t1"})

        from_second_schema = self._lookup(schema="schema_2")
        self.assertEqual({key.name for key in from_second_schema}, {"t2"})

        # One query per distinct schema.
        self.assertEqual(self.mock_connection.execute.call_count, 2)


if __name__ == "__main__":
unittest.main()
Loading