Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/pyspark/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#
from pyspark.pipelines.api import (
append_flow,
create_auto_cdc_flow,
create_streaming_table,
materialized_view,
table,
Expand All @@ -25,6 +26,7 @@

__all__ = [
"append_flow",
"create_auto_cdc_flow",
"create_streaming_table",
"materialized_view",
"table",
Expand Down
194 changes: 192 additions & 2 deletions python/pyspark/pipelines/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Dict, List, Optional, Union, overload
from typing import Callable, Dict, List, Literal, Optional, Union, overload

from pyspark.errors import PySparkTypeError
from pyspark.pipelines.graph_element_registry import get_active_graph_element_registry
from pyspark.pipelines.type_error_utils import validate_optional_list_of_str_arg
from pyspark.pipelines.flow import Flow, QueryFunction
from pyspark.pipelines.flow import AutoCdcFlow, Flow, QueryFunction
from pyspark.pipelines.source_code_location import (
get_caller_source_code_location,
)
Expand All @@ -29,6 +29,7 @@
TemporaryView,
Sink,
)
from pyspark.sql import Column
from pyspark.sql.types import StructType


Expand Down Expand Up @@ -525,3 +526,192 @@ def create_sink(
comment=None,
)
get_active_graph_element_registry().register_output(sink)


def create_auto_cdc_flow(
    target: str,
    source: str,
    keys: Union[List[str], List[Column]],
    sequence_by: Union[str, Column],
    apply_as_deletes: Optional[Union[str, Column]] = None,
    column_list: Optional[Union[List[str], List[Column]]] = None,
    except_column_list: Optional[Union[List[str], List[Column]]] = None,
    stored_as_scd_type: Optional[Literal[1, "1"]] = None,
    name: Optional[str] = None,
) -> None:
    """
    Create an Auto CDC flow into the target table from the Change Data Capture (CDC) source.
    Target table must have already been created using create_streaming_table function. Only one
    of column_list and except_column_list can be specified.

    Example:
        create_auto_cdc_flow(
            target="target",
            source="source",
            keys=["key"],
            sequence_by="sequence_expr",
            column_list=["key", "value"],
        )

    Note that for keys, sequence_by, column_list, and except_column_list the arguments have to
    be column identifiers without qualifiers, e.g. they cannot be col("sourceTable.keyId").

    :param target: The name of the target table that receives the Auto CDC flow.
    :param source: The name of the CDC source to stream from.
    :param keys: The column or combination of columns that uniquely identify a row in the source \
        data. This is used to identify which CDC events apply to specific records in the target \
        table. These keys also identify records in the target table, e.g., if there exists a record \
        for given keys and the CDC source has an UPSERT operation for the same keys, we will update \
        the existing record. At least one key must be provided. This should be a list of column \
        identifiers without qualifiers, expressed as either Python strings or PySpark Columns.
    :param sequence_by: An expression that we use to order the source data. This can be expressed \
        as either a SQL expression string or a PySpark Column.
    :param apply_as_deletes: Delete condition for the merged operation. This can be expressed as \
        either a SQL expression string, e.g. "operation = 'DELETE'", or a PySpark Column.
    :param column_list: Columns that will be included in the output table. This should be a list \
        of column identifiers without qualifiers, expressed as either Python strings or PySpark \
        Column. Only one of column_list and except_column_list can be specified.
    :param except_column_list: Columns that will be excluded in the output table. This should be a \
        list of column identifiers without qualifiers, expressed as either Python strings or \
        PySpark Column. Only one of column_list and except_column_list can be specified. When this \
        is specified, all columns in the dataframe of the target table except those in this list \
        will be in the output table.
    :param stored_as_scd_type: The SCD type for the target table. Only 1 (or "1") is supported. \
        When not specified the server default applies.
    :param name: The name of the flow for this create_auto_cdc_flow command. When unspecified \
        this will build a "default flow" with name equal to the target name.
    :raises PySparkTypeError: With error class NOT_EXPECTED_TYPE when an argument does not have \
        one of the accepted types, or when stored_as_scd_type is neither 1 nor "1".
    """
    from pyspark.sql.connect.functions.builtin import expr as _connect_expr

    # Exact type checks (rather than isinstance) are kept on purpose: subclasses of str are
    # rejected for these three arguments, exactly as before.
    for str_arg_name, str_arg_value in (("target", target), ("source", source)):
        if type(str_arg_value) is not str:
            raise _auto_cdc_type_error(str_arg_name, "str", str_arg_value)
    if name is not None and type(name) is not str:
        raise _auto_cdc_type_error("name", "str", name)

    # An unnamed flow takes the name of its target.
    flow_name = target if name is None else name

    # NOTE(review): mutual exclusivity of column_list / except_column_list and the
    # "at least one key" rule are not checked in this function; presumably they are
    # enforced downstream - confirm.
    key_columns = _normalize_column_list(arg_name="keys", column_list=keys)
    included_columns = _normalize_optional_column_list(
        arg_name="column_list", column_list=column_list
    )
    excluded_columns = _normalize_optional_column_list(
        arg_name="except_column_list", column_list=except_column_list
    )

    # Strings are parsed as SQL expressions; Column values pass through untouched.
    if isinstance(sequence_by, str):
        sequence_by_column = _connect_expr(sequence_by)
    elif isinstance(sequence_by, Column):
        sequence_by_column = sequence_by
    else:
        raise _auto_cdc_type_error("sequence_by", "str or Column", sequence_by)

    delete_condition: Optional[Column]
    if isinstance(apply_as_deletes, str):
        delete_condition = _connect_expr(apply_as_deletes)
    elif apply_as_deletes is None or isinstance(apply_as_deletes, Column):
        delete_condition = apply_as_deletes
    else:
        raise _auto_cdc_type_error("apply_as_deletes", "str or Column", apply_as_deletes)

    # Comparing the string form accepts both the int 1 and the string "1".
    if stored_as_scd_type is not None and str(stored_as_scd_type) != "1":
        raise _auto_cdc_type_error("stored_as_scd_type", "Literal[1, '1']", stored_as_scd_type)

    # Must be called directly from this function so the stack level still refers to
    # the code that invoked create_auto_cdc_flow.
    source_code_location = get_caller_source_code_location(stacklevel=1)

    flow = AutoCdcFlow(
        name=flow_name,
        target=target,
        source=source,
        keys=key_columns,
        sequence_by=sequence_by_column,
        apply_as_deletes=delete_condition,
        column_list=included_columns,
        except_column_list=excluded_columns,
        stored_as_scd_type=stored_as_scd_type,
        source_code_location=source_code_location,
    )

    get_active_graph_element_registry().register_auto_cdc_flow(flow)


def _auto_cdc_type_error(arg_name: str, expected_type: str, value: object) -> PySparkTypeError:
    """Build the NOT_EXPECTED_TYPE error reported for a create_auto_cdc_flow argument.

    :param arg_name: Name of the offending argument.
    :param expected_type: Human-readable description of the accepted type(s).
    :param value: The rejected value; only the name of its type is reported.
    :return: The error instance, left for the caller to raise.
    """
    return PySparkTypeError(
        errorClass="NOT_EXPECTED_TYPE",
        messageParameters={
            "arg_name": arg_name,
            "expected_type": expected_type,
            "arg_type": type(value).__name__,
        },
    )


def _normalize_optional_column_list(
    arg_name: str,
    column_list: Optional[Union[List[str], List[Column]]],
) -> Optional[List[Column]]:
    """Normalize a column list that is allowed to be absent.

    :param arg_name: Name of the argument being normalized, forwarded for error reporting.
    :param column_list: The list to normalize, or None.
    :return: None when no list was given, otherwise the result of _normalize_column_list.
    """
    if column_list is not None:
        return _normalize_column_list(arg_name=arg_name, column_list=column_list)
    return None


def _normalize_column_list(
    arg_name: str,
    column_list: Union[List[str], List[Column]],
) -> List[Column]:
    """Convert a list of column names and/or Column objects into a list of Columns.

    :param arg_name: Name of the argument being normalized; used in error messages.
    :param column_list: A list whose items are strings or Column objects.
    :return: A new list in the same order, with each string turned into a Column.
    :raises PySparkTypeError: With error class NOT_EXPECTED_TYPE when column_list is not a
        list, or when one of its items is neither a str nor a Column.
    """
    from pyspark.sql.connect.functions.builtin import col as _connect_col

    def _unexpected(offender: object) -> PySparkTypeError:
        # The same error shape is used for a bad container and for a bad item; only the
        # reported type name differs.
        return PySparkTypeError(
            errorClass="NOT_EXPECTED_TYPE",
            messageParameters={
                "arg_name": arg_name,
                "expected_type": "list[str] or list[Column]",
                "arg_type": type(offender).__name__,
            },
        )

    if not isinstance(column_list, list):
        raise _unexpected(column_list)

    result: List[Column] = []
    for item in column_list:
        if isinstance(item, str):
            result.append(_connect_col(item))
            continue
        if not isinstance(item, Column):
            # Fail on the first offending item.
            raise _unexpected(item)
        result.append(item)

    return result
34 changes: 33 additions & 1 deletion python/pyspark/pipelines/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# limitations under the License.
#
from dataclasses import dataclass
from typing import Callable, Dict
from typing import Callable, Dict, List, Literal, Optional

from pyspark.sql import DataFrame
from pyspark.sql import Column
from pyspark.pipelines.source_code_location import SourceCodeLocation

QueryFunction = Callable[[], DataFrame]
Expand All @@ -41,3 +42,34 @@ class Flow:
spark_conf: Dict[str, str]
source_code_location: SourceCodeLocation
func: QueryFunction


@dataclass(frozen=True)
class AutoCdcFlow:
    """Immutable definition of an Auto CDC flow in a pipeline dataflow graph.

    An Auto CDC flow applies Change Data Capture (CDC) events from a source to a target
    streaming table. Each attribute is described by the comment directly above it.
    """

    # Optional name of the flow. None is meant to stand for the target name; this class
    # stores the value as given and applies no default itself.
    name: Optional[str]
    # The name of the target streaming table.
    target: str
    # The name of the CDC source to stream from.
    source: str
    # Column(s) that uniquely identify a row in source and target data.
    keys: List[Column]
    # Expression used to order the source data.
    sequence_by: Column
    # Optional delete condition for the merged operation.
    apply_as_deletes: Optional[Column]
    # Optional columns to include in the output table.
    column_list: Optional[List[Column]]
    # Optional columns to exclude from the output table.
    except_column_list: Optional[List[Column]]
    # Optional SCD type for the target table. Only 1 is supported.
    stored_as_scd_type: Optional[Literal[1, "1"]]
    # The location of the source code that created this flow.
    source_code_location: SourceCodeLocation
6 changes: 5 additions & 1 deletion python/pyspark/pipelines/graph_element_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pathlib import Path

from pyspark.pipelines.output import Output
from pyspark.pipelines.flow import Flow
from pyspark.pipelines.flow import AutoCdcFlow, Flow
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Generator, Optional
Expand All @@ -42,6 +42,10 @@ def register_output(self, output: Output) -> None:
def register_flow(self, flow: Flow) -> None:
"""Add the given flow to the registry."""

@abstractmethod
def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None:
    """Record an Auto CDC flow definition in this registry.

    :param flow: The Auto CDC flow to add.
    """

@abstractmethod
def register_sql(self, sql_text: str, file_path: Path) -> None:
"""Register a string containing SQL statements the dataflow graph.
Expand Down
40 changes: 37 additions & 3 deletions python/pyspark/pipelines/spark_connect_graph_element_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from pathlib import Path

from pyspark.errors import PySparkTypeError
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession, Column
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.pipelines.output import (
Output,
Expand All @@ -27,12 +27,12 @@
StreamingTable,
TemporaryView,
)
from pyspark.pipelines.flow import Flow
from pyspark.pipelines.flow import AutoCdcFlow, Flow
from pyspark.pipelines.graph_element_registry import GraphElementRegistry
from pyspark.pipelines.source_code_location import SourceCodeLocation
from pyspark.sql.connect.types import pyspark_types_to_proto_types
from pyspark.sql.types import StructType
from typing import Any, cast
from typing import Any, List, Optional, cast
import pyspark.sql.connect.proto as pb2
from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context

Expand Down Expand Up @@ -133,6 +133,40 @@ def register_flow(self, flow: Flow) -> None:
command.pipeline_command.define_flow.CopyFrom(inner_command)
self._client.execute_command(command)

def register_auto_cdc_flow(self, flow: AutoCdcFlow) -> None:
    """Send a DefineFlow pipeline command describing the given Auto CDC flow.

    Columns held by the flow are converted to plans with their ``to_plan`` method before
    being placed in the AutoCdcFlowDetails message. ``stored_as_scd_type`` and
    ``apply_as_deletes`` are only set on the message when the flow provides them.

    :param flow: The Auto CDC flow definition to register. When ``flow.name`` is None the
        flow is registered under the name of its target.
    """
    from pyspark.sql.connect.column import Column as ConnectColumn

    def to_plan(col: Column) -> Any:
        # The cast only informs the type checker; the column is expected to already be a
        # Spark Connect column.
        return cast(ConnectColumn, col).to_plan(self._client)

    def to_plans(cols: Optional[List[Column]]) -> List[Any]:
        # A missing list is sent as an empty repeated field.
        return [] if cols is None else [to_plan(c) for c in cols]

    auto_cdc_details = pb2.PipelineCommand.DefineFlow.AutoCdcFlowDetails(
        source=flow.source,
        keys=to_plans(flow.keys),
        sequence_by=to_plan(flow.sequence_by),
        column_list=to_plans(flow.column_list),
        except_column_list=to_plans(flow.except_column_list),
    )
    if flow.stored_as_scd_type is not None:
        # Any provided value maps to SCD type 1, the only enum value used here.
        auto_cdc_details.stored_as_scd_type = pb2.PipelineCommand.DefineFlow.SCDType.SCD_TYPE_1
    if flow.apply_as_deletes is not None:
        auto_cdc_details.apply_as_deletes.CopyFrom(to_plan(flow.apply_as_deletes))

    # AutoCdcFlow.name is Optional; fall back to the target name instead of sending
    # the command without a flow name.
    flow_name = flow.name if flow.name is not None else flow.target

    inner_command = pb2.PipelineCommand.DefineFlow(
        dataflow_graph_id=self._dataflow_graph_id,
        flow_name=flow_name,
        target_dataset_name=flow.target,
        auto_cdc_flow_details=auto_cdc_details,
        sql_conf={},
        source_code_location=source_code_location_to_proto(flow.source_code_location),
    )

    command = pb2.Command()
    command.pipeline_command.define_flow.CopyFrom(inner_command)
    self._client.execute_command(command)

def register_sql(self, sql_text: str, file_path: Path) -> None:
inner_command = pb2.PipelineCommand.DefineSqlGraphElements(
dataflow_graph_id=self._dataflow_graph_id,
Expand Down
Loading