python/pfs/utils/database/db.py: 99 changes (95 additions, 4 deletions)
@@ -1,14 +1,16 @@
+import csv
 import logging
 import re
+import time
 from contextlib import contextmanager
+from io import StringIO
 from threading import RLock
 from typing import Any, Mapping, Optional, Union
 
 import numpy as np
 import pandas as pd
-from sqlalchemy import MetaData, Table, create_engine, text
+from sqlalchemy import MetaData, Table, create_engine, exc, text
 from sqlalchemy.engine import Connection, Engine
-from sqlalchemy import exc
 
 _DB_ENGINES: dict[str, Engine] = {}
 _DB_ENGINES_LOCK = RLock()
@@ -579,10 +581,16 @@ def insert_dataframe(
         df: pd.DataFrame,
         index: bool = False,
         chunksize: int | None = None,
+        use_copy: bool = True,
         **kwargs: Any,
     ) -> int | None:
         """Insert into a table via a dataframe.
 
+        By default, this method uses PostgreSQL's ``COPY`` command for bulk
+        inserts via the ``_psql_insert_copy`` helper, which is generally
+        faster for large inserts. If ``use_copy`` is set to ``False``, it
+        falls back to the ``multi`` insert method provided by pandas.
+
         Parameters
         ----------
         table : str
@@ -596,6 +604,10 @@ def insert_dataframe(
             can be sent in a single query, so the effective maximum chunk size
             depends on the number of columns in the DataFrame. When ``None``, the
             chunk size will be determined by the number of columns, i.e. ``65535 // n_columns``.
+        use_copy : bool, default True
+            If ``True``, use PostgreSQL's ``COPY`` command for bulk inserts
+            via the ``_psql_insert_copy`` helper; this is generally faster
+            for large inserts. If ``False``, use the ``multi`` method provided by pandas.
 
         Returns
         -------
@@ -647,20 +659,26 @@ def insert_dataframe(
             self.logger.error(f"Failed to reflect table '{table}': {e}")
             raise
 
+        if use_copy:
+            self.logger.debug("Scrubbing DataFrame for safe COPY insertion...")
+            df = scrub_dataframe_for_copy(df)
+
         try:
             self.logger.info(f"Starting insert of {len(df)} rows into table '{table}'...")
+            t0 = time.perf_counter()
             with self.connection() as conn:
                 inserted_rows = df.to_sql(
                     name=table,
                     con=conn,
                     if_exists="append",
                     index=index,
                     chunksize=chunksize,
-                    method="multi",
+                    method=_psql_insert_copy if use_copy else "multi",
                     dtype=dtype_map,
                 )
+            t1 = time.perf_counter()
 
-            self.logger.info(f"Successfully inserted {inserted_rows} rows into '{table}'")
+            self.logger.info(f"Successfully inserted {inserted_rows} rows into '{table}' in {t1 - t0:.2f} seconds.")
             return inserted_rows
         except Exception as e:
             db_error = str(getattr(e, "orig", e))
@@ -703,3 +721,76 @@ def insert_kw(self, table: str, **kwargs: Any) -> None:
             conn.execute(ins)
 
         return None
+
+
+def _psql_insert_copy(table, conn, keys, data_iter):
+    """Insert data into PostgreSQL using the COPY command.
+
+    Implements the callable interface accepted by the ``method`` argument
+    of ``pandas.DataFrame.to_sql``; requires a psycopg 3 connection, whose
+    cursors provide the ``copy()`` context manager.
+    """
+    # Serialise the incoming rows to CSV in memory so that they can be
+    # streamed to the server in a single COPY operation.
+    csv_buffer = StringIO()
+    writer = csv.writer(
+        csv_buffer,
+        delimiter=",",
+        quotechar='"',
+        escapechar="\\",
+        lineterminator="\n",
+        quoting=csv.QUOTE_MINIMAL,
+    )
+    writer.writerows(data_iter)
+    csv_buffer.seek(0)
+
+    # Unwrap the SQLAlchemy connection to the underlying DBAPI connection.
+    raw_conn = conn.connection
+
+    with raw_conn.cursor() as cur:
+        columns = ", ".join(f'"{k}"' for k in keys)
+        # Quote the table name, qualified by its schema when one is set.
+        table_name = f'"{table.schema}"."{table.name}"' if table.schema else f'"{table.name}"'
+        sql = f"""
+            COPY {table_name} ({columns})
+            FROM STDIN
+            WITH (
+                FORMAT CSV,
+                NULL '',
+                QUOTE '"',
+                DELIMITER ',',
+                HEADER FALSE
+            )
+        """
+
+        try:
+            with cur.copy(sql) as copy:
+                copy.write(csv_buffer.read())
+
+            return cur.rowcount
+        except Exception:
+            logging.error(
+                "Error during COPY into table %s with columns %s. SQL: %s",
+                getattr(table, "name", table),
+                list(keys),
+                sql,
+            )
+            raise
+
+
+def scrub_dataframe_for_copy(df):
+    """Prepare a DataFrame for safe insertion via ``COPY``.
+
+    1. Converts float columns that contain only integral values (and NaNs)
+       to the nullable ``Int64`` dtype, so e.g. ``99.0`` is written as ``99``.
+    2. Replaces empty strings and NaNs in object columns with ``None``, so
+       they are emitted as empty fields, which ``COPY`` reads as NULL.
+    """
+    df_clean = df.copy()
+
+    for col in df_clean.select_dtypes(include=["float"]):
+        # Check whether all non-null values in this column are integral:
+        # (value % 1) is 0 for whole numbers, e.g. 99.0 % 1 == 0 (True),
+        # but 99.5 % 1 == 0.5 (False).
+        series = df_clean[col].dropna()
+
+        if len(series) > 0 and (series % 1 == 0).all():
+            # Safe to convert: this turns 99.0 -> 99 and NaN -> <NA>.
+            df_clean[col] = df_clean[col].astype("Int64")
+
+    # Normalise object columns so that empty strings and NaNs become None;
+    # both are then emitted as empty fields, which COPY reads as NULL.
+    for col in df_clean.select_dtypes(include=["object"]):
+        df_clean[col] = df_clean[col].replace({"": None, np.nan: None})
+
+    return df_clean
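
Editor's note: a minimal usage sketch of the new COPY path, assuming an instance `db` of the class that defines `insert_dataframe` and a hypothetical table name `my_table`; the engine must use the psycopg 3 driver, since `_psql_insert_copy` relies on its `cursor.copy()` API. Only the scrubbing calls below run as-is; the insert call is illustrative.

    import numpy as np
    import pandas as pd

    from pfs.utils.database.db import scrub_dataframe_for_copy

    # Hypothetical data: "visit" is float-typed only because of the NaN,
    # even though every present value is integral.
    df = pd.DataFrame({
        "visit": [1.0, 2.0, np.nan],
        "note": ["ok", "", None],
    })

    clean = scrub_dataframe_for_copy(df)
    print(clean.dtypes["visit"])   # Int64: 1.0 -> 1, NaN -> <NA>
    print(clean["note"].tolist())  # ['ok', None, None]; '' becomes NULL

    # With use_copy=True (the default), insert_dataframe scrubs the frame
    # itself and routes the rows through _psql_insert_copy:
    # rows = db.insert_dataframe("my_table", clean, use_copy=True)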