Merged
21 commits
ebcb1c7
feat: replace cursor.execute() with run_sql utility for executing SQL…
vinay79n Dec 2, 2025
28bd3ec
bump version
vinay79n Dec 2, 2025
f84e8c3
feat: added some copilot code correction suggestions and for the criti…
vinay79n Dec 2, 2025
657276c
lint fix
vinay79n Dec 2, 2025
ea578dc
feat: add --no-pylint option to test command in pandas read/write flo…
vinay79n Dec 2, 2025
01b3d68
fix: refactor query_pandas_from_snowflake to use run_sql for executin…
vinay79n Dec 3, 2025
489a1a6
fix: remove --no-pylint option from test command in pandas read/write…
vinay79n Dec 3, 2025
529642d
chore: removed old commented code
vinay79n Dec 3, 2025
fc67c37
lint fix
vinay79n Dec 3, 2025
8927f14
feat: add integration style tests for run_sql integration with publis…
vinay79n Dec 3, 2025
4a933d6
feat: rename run_sql to _execute_sql and moved it to _snowflake pri…
vinay79n Dec 3, 2025
8669655
feat: move _execute_sql to shared module and update references across…
vinay79n Dec 3, 2025
3eae576
lint fix
vinay79n Dec 3, 2025
6306d8f
feat: update _execute_sql integration tests to handle comment-only SQ…
vinay79n Dec 3, 2025
99c792c
fix: remove comment only sql test
vinay79n Dec 3, 2025
3248700
removed redundant test
vinay79n Dec 3, 2025
605c474
feat: renamed generic "shared" module name. Added error handling
avr2002 Dec 4, 2025
d89ff79
feat: refactor _execute_sql import and error handling accessing metaf…
avr2002 Dec 4, 2025
e957fea
refactor: using updated module name for _execute_sql function
avr2002 Dec 4, 2025
66314b0
test: removed redundant tests and adding new unit test for _execute_s…
avr2002 Dec 4, 2025
dcae2a5
fix: correct workflow name
avr2002 Dec 4, 2025
6 changes: 3 additions & 3 deletions .github/workflows/ci-cd-ds-platform-utils.yaml
@@ -1,4 +1,4 @@
-name: Publish DS Projen
+name: Publish DS Platform Utils

on:
  workflow_dispatch:
@@ -16,7 +16,7 @@ jobs:
      - name: Checkout Repository
        uses: actions/checkout@v4
        with:
-          fetch-depth: 0 # Fetch all history for version tagging
+          fetch-depth: 0  # Fetch all history for version tagging

      - name: Set up uv
        uses: astral-sh/setup-uv@v5
@@ -44,7 +44,7 @@ jobs:
          cache-dependency-glob: "${{ github.workspace }}/uv.lock"

      - name: Run pre-commit hooks
-        run: SKIP=no-commit-to-branch uv run poe lint # using poethepoet needs to be setup before using poe lint
+        run: SKIP=no-commit-to-branch uv run poe lint  # using poethepoet needs to be setup before using poe lint

  build-wheel:
    name: Build Wheel
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ds-platform-utils"
-version = "0.2.3"
+version = "0.3.0"
description = "Utility library for Pattern Data Science."
readme = "README.md"
authors = [
41 changes: 41 additions & 0 deletions src/ds_platform_utils/_snowflake/run_query.py
@@ -0,0 +1,41 @@
"""Shared Snowflake utility functions."""

import warnings
from typing import Iterable, Optional

from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor
from snowflake.connector.errors import ProgrammingError


def _execute_sql(conn: SnowflakeConnection, sql: str) -> Optional[SnowflakeCursor]:
    """Execute SQL statement(s) using Snowflake's ``connection.execute_string()`` and return the *last* resulting cursor.

    Snowflake's ``execute_string`` allows a single string containing multiple SQL
    statements (separated by semicolons) to be executed at once. Unlike
    ``cursor.execute()``, which handles exactly one statement and returns a single
    cursor object, ``execute_string`` returns a **list of cursors**—one cursor for each
    individual SQL statement in the batch.

    :param conn: Snowflake connection object
    :param sql: SQL query or batch of semicolon-delimited SQL statements
    :return: The cursor corresponding to the last executed statement, or None if no
        statements were executed or if the SQL contains only whitespace/comments
    """
    if not sql.strip():
        return None

    try:
        cursors: Iterable[SnowflakeCursor] = conn.execute_string(sql.strip())

        if cursors is None:
            return None

        *_, last = cursors
        return last
    except ProgrammingError as e:
        if "Empty SQL statement" in str(e):
            # emit a warning and return None
            warnings.warn("Empty SQL statement encountered; returning None.", category=UserWarning, stacklevel=2)
            return None
        raise
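
For context, a minimal usage sketch of the new helper (not part of the diff; `conn` is assumed to be an open SnowflakeConnection created elsewhere, and MY_WH is a made-up warehouse name):

# Usage sketch, not from this PR: exercising _execute_sql directly.
from ds_platform_utils._snowflake.run_query import _execute_sql

cursor = _execute_sql(conn, "USE WAREHOUSE MY_WH; SELECT 42 AS answer;")
if cursor is not None:
    print(cursor.fetchall())  # cursor belongs to the *last* statement: [(42,)]
else:
    print("nothing to execute")  # empty/whitespace-only SQL yields None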
20 changes: 15 additions & 5 deletions src/ds_platform_utils/_snowflake/write_audit_publish.py
@@ -7,6 +7,7 @@
from jinja2 import DebugUndefined, Template
from snowflake.connector.cursor import SnowflakeCursor

+from ds_platform_utils._snowflake.run_query import _execute_sql
from ds_platform_utils.metaflow._consts import NON_PROD_SCHEMA, PROD_SCHEMA


@@ -200,8 +201,8 @@ def run_query(query: str, cursor: Optional[SnowflakeCursor] = None) -> None:
        print(f"Would execute query:\n{query}")
        return

-    # Count statements so we can tell Snowflake exactly how many to expect
-    cursor.execute(query, num_statements=0)  # 0 means any number of statements
+    # run the query using _execute_sql utility which handles multiple statements via execute_string
+    _execute_sql(cursor.connection, query)
    cursor.connection.commit()


@@ -216,7 +217,10 @@ def run_audit_query(query: str, cursor: Optional[SnowflakeCursor] = None) -> dict:
    if cursor is None:
        return {"mock_result": True}

-    cursor.execute(query)
+    cursor = _execute_sql(cursor.connection, query)
+    if cursor is None:
+        return {}
+
    result = cursor.fetchone()
    if not result:
        return {}
@@ -243,11 +247,17 @@ def fetch_table_preview(
    if not cursor:
        return [{"mock_col": "mock_val"}]

-    cursor.execute(f"""
+    cursor = _execute_sql(
+        cursor.connection,
+        f"""
        SELECT *
        FROM {database}.{schema}.{table_name}
        LIMIT {n_rows};
-    """)
+        """,
+    )
+    if cursor is None:
+        return []
+
    columns = [col[0] for col in cursor.description]
    rows = cursor.fetchall()
    return [dict(zip(columns, row)) for row in rows]
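
A short behavior sketch of the new fallbacks (assuming `cur` is a live SnowflakeCursor; not part of the diff):

# Sketch: SQL that reduces to nothing no longer breaks these helpers.
run_query("", cursor=cur)                     # _execute_sql returns None; commit still runs
assert run_audit_query("", cursor=cur) == {}  # a None cursor now maps to an empty audit result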
18 changes: 12 additions & 6 deletions src/ds_platform_utils/metaflow/get_snowflake_connection.py
@@ -4,6 +4,8 @@
from metaflow import Snowflake, current
from snowflake.connector import SnowflakeConnection

+from ds_platform_utils._snowflake.run_query import _execute_sql
+
####################
# --- Metaflow --- #
####################
@@ -41,7 +43,12 @@ def get_snowflake_connection(
    In metaflow, each step is a separate Python process, so the connection will automatically be
    closed at the end of any steps that use this singleton.
    """
-    return _create_snowflake_connection(use_utc=use_utc, query_tag=current.project_name)
+    if current and hasattr(current, "project_name"):
+        query_tag = current.project_name
+    else:
+        query_tag = None
+
+    return _create_snowflake_connection(use_utc=use_utc, query_tag=query_tag)


#####################
@@ -66,11 +73,10 @@ def _create_snowflake_connection(
    if query_tag:
        queries.append(f"ALTER SESSION SET QUERY_TAG = '{query_tag}';")

-    # Execute all queries in single batch
-    with conn.cursor() as cursor:
-        sql = "\n".join(queries)
-        _debug_print_query(sql)
-        cursor.execute(sql, num_statements=0)
+    # Merge into single SQL batch
+    sql = "\n".join(queries)
+    _debug_print_query(sql)
+    _execute_sql(conn, sql)

    return conn

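Illustratively, the session bootstrap now goes through `_execute_sql` as a single batch; a sketch (the statement texts below are assumptions, not taken from the diff):

# Sketch only: the kind of batch _create_snowflake_connection sends in one
# execute_string round trip. Exact statements are assumed.
queries = [
    "ALTER SESSION SET TIMEZONE = 'UTC';",          # assumed effect of use_utc=True
    "ALTER SESSION SET QUERY_TAG = 'my_project';",  # per the query_tag branch above
]
_execute_sql(conn, "\n".join(queries))  # no explicit cursor management needed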
36 changes: 20 additions & 16 deletions src/ds_platform_utils/metaflow/pandas.py
@@ -11,6 +11,7 @@
from snowflake.connector import SnowflakeConnection
from snowflake.connector.pandas_tools import write_pandas

+from ds_platform_utils._snowflake.run_query import _execute_sql
from ds_platform_utils.metaflow._consts import NON_PROD_SCHEMA, PROD_SCHEMA
from ds_platform_utils.metaflow.get_snowflake_connection import _debug_print_query, get_snowflake_connection
from ds_platform_utils.metaflow.write_audit_publish import (
@@ -111,15 +112,14 @@ def publish_pandas(  # noqa: PLR0913 (too many arguments)

    # set warehouse
    if warehouse is not None:
-        with conn.cursor() as cur:
-            cur.execute(f"USE WAREHOUSE {warehouse};")
+        _execute_sql(conn, f"USE WAREHOUSE {warehouse};")

-            # set query tag for cost tracking in select.dev
-            # REASON: because write_pandas() doesn't allow modifying the SQL query to add SQL comments in it directly,
-            # so we set a session query tag instead.
-            tags = get_select_dev_query_tags()
-            query_tag_str = json.dumps(tags)
-            cur.execute(f"ALTER SESSION SET QUERY_TAG = '{query_tag_str}';")
+    # set query tag for cost tracking in select.dev
+    # REASON: because write_pandas() doesn't allow modifying the SQL query to add SQL comments in it directly,
+    # so we set a session query tag instead.
+    tags = get_select_dev_query_tags()
+    query_tag_str = json.dumps(tags)
+    _execute_sql(conn, f"ALTER SESSION SET QUERY_TAG = '{query_tag_str}';")

    # https://docs.snowflake.com/en/developer-guide/snowpark/reference/python/latest/snowpark/api/snowflake.snowpark.Session.write_pandas
    write_pandas(
@@ -198,16 +198,20 @@ def query_pandas_from_snowflake(
        current.card.append(Markdown(f"```sql\n{query}\n```"))

    conn: SnowflakeConnection = get_snowflake_connection(use_utc)
-    with conn.cursor() as cur:
-        if warehouse is not None:
-            cur.execute(f"USE WAREHOUSE {warehouse};")
+    if warehouse is not None:
+        _execute_sql(conn, f"USE WAREHOUSE {warehouse};")

+    cursor_result = _execute_sql(conn, query)
+    if cursor_result is None:
+        # No statements to execute, return empty DataFrame
+        df = pd.DataFrame()
+    else:
        # force_return_table=True -- returns a Pyarrow Table always even if the result is empty
-        result: pyarrow.Table = cur.execute(query).fetch_arrow_all(force_return_table=True)
+        result: pyarrow.Table = cursor_result.fetch_arrow_all(force_return_table=True)
        df = result.to_pandas()
        df.columns = df.columns.str.lower()

-        current.card.append(Markdown("### Query Result"))
-        current.card.append(Table.from_dataframe(df.head()))
-        return df
+    current.card.append(Markdown("### Query Result"))
+    current.card.append(Table.from_dataframe(df.head()))
+
+    return df
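
A call sketch of the new behavior (the signature is assumed from the diff context: a query string plus optional warehouse/use_utc; not part of the PR):

# Sketch: multi-statement and empty inputs under the new code path.
df = query_pandas_from_snowflake("SELECT 1 AS x; SELECT 2 AS y;")
# df holds only the last statement's result, with lower-cased column names.
empty = query_pandas_from_snowflake("   ")  # whitespace-only SQL -> empty DataFrame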
55 changes: 32 additions & 23 deletions src/ds_platform_utils/metaflow/write_audit_publish.py
@@ -10,6 +10,7 @@
from metaflow.cards import Artifact, Markdown, Table
from snowflake.connector.cursor import SnowflakeCursor

+from ds_platform_utils._snowflake.run_query import _execute_sql
from ds_platform_utils.metaflow.get_snowflake_connection import get_snowflake_connection

if TYPE_CHECKING:
@@ -97,7 +98,7 @@ def get_select_dev_query_tags() -> Dict[str, str]:
            stacklevel=2,
        )

-    def extract(prefix: str, default: str = "unknown") -> str:
+    def _extract(prefix: str, default: str = "unknown") -> str:
        for tag in fetched_tags:
            if tag.startswith(prefix + ":"):
                return tag.split(":", 1)[1]
@@ -106,19 +107,19 @@ def extract(prefix: str, default: str = "unknown") -> str:
    # most of these will be unknown if no tags are set on the flow
    # (most likely for the flow runs which are triggered manually locally)
    return {
-        "app": extract(
+        "app": _extract(
            "ds.domain"
        ),  # first tag after 'app:', is the domain of the flow, fetched from current tags of the flow
-        "workload_id": extract(
+        "workload_id": _extract(
            "ds.project"
        ),  # second tag after 'workload_id:', is the project of the flow which it belongs to
-        "flow_name": current.flow_name,  # name of the metaflow flow
+        "flow_name": current.flow_name,
        "project": current.project_name,  # Project name from the @project decorator, lets us
        # identify the flow’s project without relying on user tags (added via --tag).
        "step_name": current.step_name,  # name of the current step
        "run_id": current.run_id,  # run_id: unique id of the current run
        "user": current.username,  # username of user who triggered the run (argo-workflows if its a deployed flow)
-        "domain": extract("ds.domain"),  # business unit (domain) of the flow, same as app
+        "domain": _extract("ds.domain"),  # business unit (domain) of the flow, same as app
        "namespace": current.namespace,  # namespace of the flow
        "perimeter": str(os.environ.get("OB_CURRENT_PERIMETER") or os.environ.get("OBP_PERIMETER")),
        "is_production": str(
@@ -216,7 +217,7 @@ def publish(  # noqa: PLR0913, D417

    with conn.cursor() as cur:
        if warehouse is not None:
-            cur.execute(f"USE WAREHOUSE {warehouse}")
+            _execute_sql(conn, f"USE WAREHOUSE {warehouse}")

        last_op_was_write = False
        for operation in write_audit_publish(
@@ -334,20 +335,28 @@ def fetch_table_preview(
    :param table_name: Table name
    :param cursor: Snowflake cursor
    """
-    cursor.execute(f"""
-        SELECT *
-        FROM {database}.{schema}.{table_name}
-        LIMIT {n_rows};
-    """)
-    columns = [col[0] for col in cursor.description]
-    rows = cursor.fetchall()
-
-    # Create header row plus data rows
-    table_rows = [[Artifact(col) for col in columns]]  # Header row
-    for row in rows:
-        table_rows.append([Artifact(val) for val in row])  # Data rows
-
-    return [
-        Markdown(f"### Table Preview: ({database}.{schema}.{table_name})"),
-        Table(table_rows),
-    ]
+    if cursor is None:
+        return []
+    else:
+        result_cursor = _execute_sql(
+            cursor.connection,
+            f"""
+            SELECT *
+            FROM {database}.{schema}.{table_name}
+            LIMIT {n_rows};
+            """,
+        )
+        if result_cursor is None:
+            return []
+        columns = [col[0] for col in result_cursor.description]
+        rows = result_cursor.fetchall()
+
+        # Create header row plus data rows
+        table_rows = [[Artifact(col) for col in columns]]  # Header row
+        for row in rows:
+            table_rows.append([Artifact(val) for val in row])  # Data rows
+
+        return [
+            Markdown(f"### Table Preview: ({database}.{schema}.{table_name})"),
+            Table(table_rows),
+        ]
59 changes: 59 additions & 0 deletions tests/unit_tests/snowflake/test__execute_sql.py
@@ -0,0 +1,59 @@
"""Functional test for _execute_sql."""

from typing import Generator

import pytest
from snowflake.connector import SnowflakeConnection

from ds_platform_utils._snowflake.run_query import _execute_sql
from ds_platform_utils.metaflow.get_snowflake_connection import get_snowflake_connection


@pytest.fixture(scope="module")
def snowflake_conn() -> Generator[SnowflakeConnection, None, None]:
    """Get a Snowflake connection for testing."""
    yield get_snowflake_connection(use_utc=True)


def test_execute_sql_empty_string(snowflake_conn):
    """Empty string returns None."""
    cursor = _execute_sql(snowflake_conn, "")
    assert cursor is None


def test_execute_sql_whitespace_only(snowflake_conn):
    """Whitespace-only string returns None."""
    cursor = _execute_sql(snowflake_conn, " \n\t ")
    assert cursor is None


def test_execute_sql_only_semicolons(snowflake_conn):
    """String with only semicolons returns None and raises warning."""
    with pytest.warns(UserWarning, match="Empty SQL statement encountered"):
        cursor = _execute_sql(snowflake_conn, " ; ;")
    assert cursor is None


def test_execute_sql_only_comments(snowflake_conn):
    """String with only comments returns None and raises warning."""
    with pytest.warns(UserWarning, match="Empty SQL statement encountered"):
        cursor = _execute_sql(snowflake_conn, "/* only comments */")
    assert cursor is None


def test_execute_sql_single_statement(snowflake_conn):
    """Single statement returns cursor with expected result."""
    cursor = _execute_sql(snowflake_conn, "SELECT 1 AS x;")
    assert cursor is not None
    rows = cursor.fetchall()
    assert len(rows) == 1
    assert rows[0][0] == 1


def test_execute_sql_multi_statement(snowflake_conn):
    """Multi-statement returns cursor for last statement only."""
    cursor = _execute_sql(snowflake_conn, "SELECT 1 AS x; SELECT 2 AS x;")
    assert cursor is not None
    rows = cursor.fetchall()
    assert len(rows) == 1
    assert rows[0][0] == 2  # Last statement result
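
These tests hit a real Snowflake account via the module-scoped fixture. The same contract could also be checked connection-free with a mock; a sketch (assuming unittest.mock, not part of the PR):

# Mock-based sketch: asserts only _execute_sql's observable behavior.
from unittest.mock import MagicMock

from ds_platform_utils._snowflake.run_query import _execute_sql


def test_execute_sql_returns_last_cursor_with_mock():
    conn = MagicMock()
    first, last = MagicMock(), MagicMock()
    conn.execute_string.return_value = [first, last]
    assert _execute_sql(conn, "SELECT 1; SELECT 2;") is last
    conn.execute_string.assert_called_once_with("SELECT 1; SELECT 2;")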
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default.