Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ gcs = [
]
git = [
"apache-airflow-providers-ssh>=3.0.0",
"dulwich>=1.0.0",
# dulwich>=1.2.0 no longer includes a paramiko vendor module we need for SSH connections.
# Until we can provide our own, exclude newer versions.
"dulwich>=1.0.0,<1.2.0",
]
postgres = [
"dbt-postgres>=1.8.0,<2.0.0",
Expand All @@ -78,7 +80,7 @@ dev = [
"coverage[toml]>=7.2",
# docutils 0.21 causes an error with poetry (https://github.com/python-poetry/poetry/issues/9293)
"docutils!=0.21",
"dulwich>=0.21,!=0.21.6",
"dulwich>=1.0.0,<1.2.0",
"freezegun>=1.1.0",
"mock-gcp==0.2.0",
"moto>=4.0.3",
Expand Down
62 changes: 47 additions & 15 deletions tests/dags/test_dbt_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
else:
from airflow.providers.common.compat.sdk import DAG

from airflow.serialization.serialized_objects import SerializedDAG

try:
from airflow.serialization.serialized_objects import DagSerialization
except ImportError:
DagSerialization = SerializedDAG

DATA_INTERVAL_START = pendulum.datetime(2022, 1, 1, tz="UTC")
DATA_INTERVAL_END = DATA_INTERVAL_START + dt.timedelta(hours=1)

Expand All @@ -57,12 +64,12 @@ def sync_dag_to_db(

def _write_dag(dag: DAG) -> SerializedDAG:
if not SerializedDagModel.has_dag(dag.dag_id):
data = SerializedDAG.to_dict(dag)
data = DagSerialization.to_dict(dag)
SerializedDagModel.write_dag(
LazyDeserializedDAG(data=data), bundle_name, session=session
)
session.flush()
return SerializedDAG.from_dict(data)
return DagSerialization.from_dict(data)

SerializedDAG.bulk_write_to_db(bundle_name, None, [dag], session=session)
_ = _write_dag(dag)
Expand Down Expand Up @@ -106,6 +113,20 @@ def _create_dagrun(
)


def _run_task_instance(ti):
# Test helper: runs the given task instance in a way that depends on the
# Airflow version flags AIRFLOW_V_3_1_PLUS and AIRFLOW_V_3_0.
# NOTE(review): indentation was lost in this rendering of the diff, so the
# nesting described in the comments below is partly inferred -- confirm
# against the real tests/dags/test_dbt_dags.py before relying on it.

# When AIRFLOW_V_3_1_PLUS is truthy nothing is executed here: the helper
# returns without touching the task instance or its state.
if AIRFLOW_V_3_1_PLUS:
return

# Only reached when the 3.1+ guard above did not return.
if AIRFLOW_V_3_0:
# Airflow 3.0.6's TaskInstance runner fails in-process with a 422.
# Workaround: dbt operators are executed directly, with a minimal context
# dict that contains only the "ti" key.
if isinstance(ti.task, DbtBaseOperator):
ti.task.execute({"ti": ti})
# The state is assigned by hand because ti.run() is bypassed on this path.
# NOTE(review): it is unclear from this view whether the next two lines sit
# inside the isinstance branch or directly under AIRFLOW_V_3_0 -- confirm.
ti.state = TaskInstanceState.SUCCESS
return

# Remaining path: delegate to the task instance's own runner, ignoring any
# previously recorded task-instance state.
ti.run(ignore_ti_state=True)


@pytest.fixture(scope="session")
def dagbag():
"""An Airflow DagBag."""
Expand Down Expand Up @@ -240,7 +261,7 @@ def test_dbt_operators_in_dag(
ti = dagrun.get_task_instance(task_id=task_id)
ti.task = basic_dag.get_task(task_id=task_id)

ti.run(ignore_ti_state=True)
_run_task_instance(ti)

assert ti.state == TaskInstanceState.SUCCESS

Expand Down Expand Up @@ -345,6 +366,12 @@ def test_dbt_operators_in_taskflow_dag(
else:
dag = taskflow_dag

if AIRFLOW_V_3_0:
for task in dag.tasks:
if isinstance(task, DbtBaseOperator):
task.profiles_dir = str(profiles_file.parent)
task.project_dir = str(dbt_project_file.parent)

dagrun = _create_dagrun(
dag,
state=DagRunState.RUNNING,
Expand All @@ -364,7 +391,7 @@ def test_dbt_operators_in_taskflow_dag(
ti = dagrun.get_task_instance(task_id=task_id)
ti.task = dag.get_task(task_id=task_id)

ti.run(ignore_ti_state=True)
_run_task_instance(ti)

assert ti.state == TaskInstanceState.SUCCESS
assert ti.task.retries == dag.default_args["retries"]
Expand All @@ -376,8 +403,9 @@ def test_dbt_operators_in_taskflow_dag(
assert failure_callback == dag.default_args["on_failure_callback"]

if isinstance(ti.task, DbtBaseOperator):
assert ti.task.profiles_dir == str(profiles_file.parent)
assert ti.task.project_dir == str(dbt_project_file.parent)
if not AIRFLOW_V_3_1_PLUS:
assert ti.task.profiles_dir == str(profiles_file.parent)
assert ti.task.project_dir == str(dbt_project_file.parent)

results = ti.xcom_pull(
task_ids=task_id,
Expand Down Expand Up @@ -493,7 +521,7 @@ def test_dbt_operators_in_connection_dag(
ti = dagrun.get_task_instance(task_id=task_id)
ti.task = target_connection_dag.get_task(task_id=task_id)

ti.run(ignore_ti_state=True)
_run_task_instance(ti)

assert ti.state == TaskInstanceState.SUCCESS

Expand Down Expand Up @@ -568,7 +596,7 @@ def test_example_basic_dag(
ti = dagrun.get_task_instance(task_id="dbt_run_hourly")
ti.task = dbt_run

ti.run(ignore_ti_state=True)
_run_task_instance(ti)

assert ti.state == TaskInstanceState.SUCCESS

Expand Down Expand Up @@ -612,6 +640,9 @@ def test_example_dbt_project_in_github_dag(
if AIRFLOW_V_3_0:
dag = DAG.from_sdk_dag(dag) # type: ignore

for task_id in ("dbt_seed", "dbt_run", "dbt_test"):
dag.get_task(task_id=task_id).dbt_conn_id = connection

dagrun = _create_dagrun(
dag,
state=DagRunState.RUNNING,
Expand All @@ -626,7 +657,7 @@ def test_example_dbt_project_in_github_dag(
ti.task = dag.get_task(task_id=task_id)
ti.task.dbt_conn_id = connection

ti.run(ignore_ti_state=True)
_run_task_instance(ti)

assert ti.state == TaskInstanceState.SUCCESS

Expand Down Expand Up @@ -669,6 +700,12 @@ def test_example_complete_dbt_workflow_dag(
if AIRFLOW_V_3_0:
dag = DAG.from_sdk_dag(dag) # type: ignore

for task in dag.tasks:
task.project_dir = dbt_project_file.parent
task.profiles_dir = profiles_file.parent
task.target = "test"
task.profile = "default"

dagrun = _create_dagrun(
dag,
state=DagRunState.RUNNING,
Expand All @@ -679,15 +716,10 @@ def test_example_complete_dbt_workflow_dag(
)

for task in dag.tasks:
task.project_dir = dbt_project_file.parent
task.profiles_dir = profiles_file.parent
task.target = "test"
task.profile = "default"

ti = dagrun.get_task_instance(task_id=task.task_id)
ti.task = task

ti.run(ignore_ti_state=True)
_run_task_instance(ti)

assert ti.state == TaskInstanceState.SUCCESS

Expand Down
8 changes: 7 additions & 1 deletion tests/hooks/test_git_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def repo(repo_dir, dbt_project_file, test_files, profiles_file, repo_branch):
repo.get_worktree().stage(f"{test_file.parent.name}/{test_file.name}")

repo.get_worktree().commit(
b"Test first commit", committer=b"Test user <test@user.com>"
b"Test first commit", committer=b"Test user <test@user.com>", sign=False
)

yield repo
Expand Down Expand Up @@ -439,6 +439,12 @@ def upload_only_target(u: URL):
server_address, server_port = git_server
destination = URL(f"git://{server_address}:{server_port}/{repo_name}")

repo = Repo(repo_dir)
config = repo.get_config()
config.set((b"commit",), b"gpgsign", False)
config.write_to_path()
repo.close()

fs_hook.upload_dbt_project(str(repo_dir), destination)

new_repo_path = tmp_path / "new_repo"
Expand Down
Loading
Loading