Skip to content
This repository was archived by the owner on May 22, 2025. It is now read-only.

Commit fece91a

Browse files
Add export_options to export_to_file operator
The current implementation of the export_to_file operator doesn't provide any mechanism for overriding or customizing the behavior of the file output. Instead the operator simply calls one of a number of Pandas.DataFrame.to_* functions with default functions that SDK users are locked in to. This modification allows for the provision of a configuration dictionary that enables the customizing or overriding of the file write behavior by passing the parameters to the FileType.create_from_dataframe implementation.
1 parent 33ca675 commit fece91a

File tree

9 files changed

+50
-16
lines changed

9 files changed

+50
-16
lines changed

python-sdk/src/astro/files/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,18 +114,20 @@ def is_pattern(self) -> bool:
114114
"""
115115
return not pathlib.PosixPath(self.path).suffix
116116

117-
def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True) -> None:
117+
def create_from_dataframe(self, df: pd.DataFrame, store_as_dataframe: bool = True, export_options: dict | None = None) -> None:
118118
"""Create a file in the desired location using the values of a dataframe.
119119
120120
:param store_as_dataframe: Whether the data should later be deserialized as a dataframe or as a file containing
121121
delimited data (e.g. csv, parquet, etc.).
122122
:param df: pandas dataframe
123+
:param export_options: additional arguments to pass to the underlying write functionality
123124
"""
124125

125126
self.is_dataframe = store_as_dataframe
127+
opts = export_options or {}
126128

127129
with self.location.get_stream() as stream:
128-
self.type.create_from_dataframe(stream=stream, df=df)
130+
self.type.create_from_dataframe(stream=stream, df=df, **opts)
129131

130132
@property
131133
def openlineage_dataset_namespace(self) -> str:

python-sdk/src/astro/files/types/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ def export_to_dataframe(self, stream, **kwargs) -> pd.DataFrame:
2727
raise NotImplementedError
2828

2929
@abstractmethod
30-
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None:
30+
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None:
3131
"""Write file to one of the supported locations
3232
3333
:param df: pandas dataframe
3434
:param stream: file stream object
35+
:param kwargs: additional arguments to pass to the underlying write functionality
3536
"""
3637
raise NotImplementedError
3738

python-sdk/src/astro/files/types/csv.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ def export_to_dataframe(
3838
return PandasDataframe.from_pandas_df(df)
3939

4040
# We need skipcq because it's a method overloading so we don't want to make it a static method
41-
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
41+
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201
4242
"""Write csv file to one of the supported locations
4343
4444
:param df: pandas dataframe
4545
:param stream: file stream object
46+
:param kwargs: additional arguments to pass to the pandas `to_csv` function
4647
"""
47-
df.to_csv(stream, index=False)
48+
49+
df.to_csv(stream, **dict(index=False, **kwargs))
4850

4951
@property
5052
def name(self):

python-sdk/src/astro/files/types/excel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,11 @@ def export_to_dataframe(
3737
return PandasDataframe.from_pandas_df(df)
3838

3939
# We need skipcq because it's a method overloading so we don't want to make it a static method
40-
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
40+
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201
4141
"""Write Excel file to one of the supported locations
4242
4343
:param df: pandas dataframe
4444
:param stream: file stream object
45+
:param kwargs: additional arguments to pass to the pandas `to_excel` function
4546
"""
46-
df.to_excel(stream, index=False)
47+
df.to_excel(stream, **dict(index=False, **kwargs))

python-sdk/src/astro/files/types/json.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@ def export_to_dataframe(
4242
return PandasDataframe.from_pandas_df(df)
4343

4444
# We need skipcq because it's a method overloading so we don't want to make it a static method
45-
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
45+
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201
4646
"""Write json file to one of the supported locations
4747
4848
:param df: pandas dataframe
4949
:param stream: file stream object
50+
:param kwargs: additional arguments to pass to the pandas `to_json` function
5051
"""
51-
df.to_json(stream, orient="records")
52+
df.to_json(stream, **dict(orient="records", **kwargs))
5253

5354
@property
5455
def name(self):

python-sdk/src/astro/files/types/ndjson.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ def export_to_dataframe(
3939
return PandasDataframe.from_pandas_df(df)
4040

4141
# We need skipcq because it's a method overloading so we don't want to make it a static method
42-
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
42+
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201
4343
"""Write ndjson file to one of the supported locations
4444
4545
:param df: pandas dataframe
4646
:param stream: file stream object
47+
:param kwargs: additional arguments to pass to the pandas `to_json` function
4748
"""
48-
df.to_json(stream, orient="records", lines=True)
49+
df.to_json(stream, **dict(orient="records", lines=True, **kwargs))
4950

5051
@property
5152
def name(self):

python-sdk/src/astro/files/types/parquet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,14 @@ def _convert_remote_file_to_byte_stream(stream) -> io.IOBase:
5757
return remote_obj_buffer
5858

5959
# We need skipcq because it's a method overloading so we don't want to make it a static method
60-
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper) -> None: # skipcq PYL-R0201
60+
def create_from_dataframe(self, df: pd.DataFrame, stream: io.TextIOWrapper, **kwargs) -> None: # skipcq PYL-R0201
6161
"""Write parquet file to one of the supported locations
6262
6363
:param df: pandas dataframe
6464
:param stream: file stream object
65+
:param kwargs: additional arguments to pass to the pandas `to_parquet` method
6566
"""
66-
df.to_parquet(stream)
67+
df.to_parquet(stream, **kwargs)
6768

6869
@property
6970
def name(self):

python-sdk/src/astro/sql/operators/export_to_file.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,23 @@ class ExportToFileOperator(AstroSQLBaseOperator):
2121
:param input_data: Table to convert to file
2222
:param output_file: File object containing the path to the file and connection id.
2323
:param if_exists: Overwrite file if exists. Default False.
24+
:param export_options: Additional options to pass to the file export functions.
2425
"""
2526

26-
template_fields = ("input_data", "output_file")
27+
template_fields = ("input_data", "output_file", "export_options")
2728

2829
def __init__(
2930
self,
3031
input_data: BaseTable | pd.DataFrame,
3132
output_file: File,
3233
if_exists: ExportExistsStrategy = "exception",
34+
export_options: dict | None = None,
3335
**kwargs,
3436
) -> None:
3537
self.output_file = output_file
3638
self.input_data = input_data
3739
self.if_exists = if_exists
40+
self.export_options = export_options or {}
3841
self.kwargs = kwargs
3942
datasets = {"output_datasets": self.output_file}
4043
if isinstance(input_data, Table):
@@ -57,7 +60,7 @@ def execute(self, context: Context) -> File: # skipcq PYL-W0613
5760
raise ValueError(f"Expected input_table to be Table or dataframe. Got {type(self.input_data)}")
5861
# Write file if overwrite == True or if file doesn't exist.
5962
if self.if_exists == "replace" or not self.output_file.exists():
60-
self.output_file.create_from_dataframe(df, store_as_dataframe=False)
63+
self.output_file.create_from_dataframe(df, store_as_dataframe=False, export_options=self.export_options)
6164
return self.output_file
6265
else:
6366
raise FileExistsError(f"{self.output_file.path} file already exists.")
@@ -144,7 +147,8 @@ def export_to_file(
144147
output_file: File,
145148
if_exists: ExportExistsStrategy = "exception",
146149
task_id: str | None = None,
147-
**kwargs: Any,
150+
export_options: dict | None = None,
151+
**kwargs,
148152
) -> XComArg:
149153
"""Convert ExportToFileOperator into a function. Returns XComArg.
150154
@@ -170,6 +174,7 @@ def export_to_file(
170174
:param input_data: Input table / dataframe
171175
:param if_exists: Overwrite file if exists. Default "exception"
172176
:param task_id: task id, optional
177+
:param export_options: Additional options to pass to the file export functions.
173178
"""
174179

175180
task_id = task_id or get_unique_task_id("export_to_file")
@@ -179,5 +184,6 @@ def export_to_file(
179184
output_file=output_file,
180185
input_data=input_data,
181186
if_exists=if_exists,
187+
export_options=export_options,
182188
**kwargs,
183189
).output

python-sdk/tests/sql/operators/test_export_file.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,25 @@ def make_df():
3737
assert df.equals(pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]}))
3838

3939

40+
def test_save_dataframe_to_local_with_options(sample_dag):
41+
@aql.dataframe
42+
def make_df():
43+
return pd.DataFrame(data={"col1": [1, 2], "col2": [3, 4]})
44+
45+
with sample_dag:
46+
df = make_df()
47+
aql.export_to_file(
48+
input_data=df,
49+
output_file=File(path="/tmp/saved_df.csv"),
50+
if_exists="replace",
51+
export_options={"header": None},
52+
)
53+
test_utils.run_dag(sample_dag)
54+
55+
df = pd.read_csv("/tmp/saved_df.csv")
56+
assert df.equals(pd.DataFrame(data={"0": [1, 2], "1": [3, 4]}))
57+
58+
4059
@pytest.mark.parametrize("database_table_fixture", [{"database": Database.SQLITE}], indirect=True)
4160
def test_save_temp_table_to_local(sample_dag, database_table_fixture):
4261
_, test_table = database_table_fixture

0 commit comments

Comments
 (0)