10 changes: 8 additions & 2 deletions pydeequ/configs.py
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-
import logging
from functools import lru_cache
import os
import re

import pyspark

SPARK_TO_DEEQU_COORD_MAPPING = {
"3.5": "com.amazon.deequ:deequ:2.0.7-spark-3.5",
@@ -22,7 +23,12 @@ def _extract_major_minor_versions(full_version: str):
@lru_cache(maxsize=None)
def _get_spark_version() -> str:
try:
spark_version = os.environ["SPARK_VERSION"]
spark_version = os.getenv("SPARK_VERSION")
if not spark_version:
spark_version = str(pyspark.__version__)
logging.info(
f"SPARK_VERSION environment variable is not set, using Spark version from PySpark {spark_version} for Deequ Maven jars"
)
except KeyError:
raise RuntimeError(f"SPARK_VERSION environment variable is required. Supported values are: {SPARK_TO_DEEQU_COORD_MAPPING.keys()}")

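For context while reading the hunk above, here is a minimal, self-contained sketch of the resolution flow this change enables: prefer an explicit SPARK_VERSION, fall back to the installed PySpark version, reduce it to major.minor, and map that to a Deequ Maven coordinate. The helper name `resolve_deequ_coordinate` and the inline major.minor truncation are illustrative assumptions; the module itself uses `_get_spark_version` and `_extract_major_minor_versions` for those steps.

```python
# Illustrative sketch only -- not the pydeequ implementation.
import os

import pyspark

SPARK_TO_DEEQU_COORD_MAPPING = {
    "3.5": "com.amazon.deequ:deequ:2.0.7-spark-3.5",
    # ... remaining entries collapsed in the diff above
}


def resolve_deequ_coordinate() -> str:
    # os.getenv returns None when the variable is unset (it never raises
    # KeyError), so the unset case falls through to the installed PySpark.
    full_version = os.getenv("SPARK_VERSION") or str(pyspark.__version__)
    major_minor = ".".join(full_version.split(".")[:2])  # "3.5.1" -> "3.5"
    try:
        return SPARK_TO_DEEQU_COORD_MAPPING[major_minor]
    except KeyError:
        raise RuntimeError(
            f"Spark {major_minor} is not supported. "
            f"Supported versions: {list(SPARK_TO_DEEQU_COORD_MAPPING)}"
        )
```

With SPARK_VERSION unset and PySpark 3.5.x installed, this resolves to `com.amazon.deequ:deequ:2.0.7-spark-3.5`, matching the mapping entry shown in the hunk.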
31 changes: 30 additions & 1 deletion tests/test_config.py
@@ -1,5 +1,17 @@
import os
from unittest import mock

import pyspark
import pytest
from pydeequ.configs import _extract_major_minor_versions

from pydeequ.configs import _extract_major_minor_versions, _get_spark_version


@pytest.fixture
def mock_env(monkeypatch):
with mock.patch.dict(os.environ, clear=True):
monkeypatch.delenv("SPARK_VERSION", raising=False)
yield


@pytest.mark.parametrize(
@@ -13,3 +25,20 @@
)
def test_extract_major_minor_versions(full_version, major_minor_version):
assert _extract_major_minor_versions(full_version) == major_minor_version


@pytest.mark.parametrize(
"spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.1"), ("3.10.3", "3.10"), ("3.10", "3.10")]
)
def test__get_spark_version_without_cache(spark_version, expected, mock_env):
with mock.patch.object(pyspark, "__version__", spark_version):
assert _get_spark_version() == expected
_get_spark_version.cache_clear()


@pytest.mark.parametrize(
"spark_version, expected", [("3.2.1", "3.2"), ("3.1", "3.2"), ("3.10.3", "3.2"), ("3.10", "3.2")]
)
def test__get_spark_version_with_cache(spark_version, expected, mock_env):
with mock.patch.object(pyspark, "__version__", spark_version):
assert _get_spark_version() == expected
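
The two parametrised tests above differ only in whether `cache_clear()` runs after each case. Because `_get_spark_version` is wrapped in `functools.lru_cache`, patching `pyspark.__version__` alone does not invalidate the cached result, which is why the second test expects the first resolved value ("3.2") for every later case. A standalone sketch of that behaviour, with a hypothetical `cached_version` helper standing in for `_get_spark_version`:

```python
# Illustrative sketch only -- shows why the "with cache" test keeps seeing "3.2".
from functools import lru_cache

state = {"version": "3.2.1"}  # stands in for pyspark.__version__


@lru_cache(maxsize=None)
def cached_version() -> str:
    return ".".join(state["version"].split(".")[:2])


assert cached_version() == "3.2"
state["version"] = "3.10.3"        # like mock.patch.object(pyspark, "__version__", ...)
assert cached_version() == "3.2"   # cached result survives the patch
cached_version.cache_clear()       # what test__get_spark_version_without_cache does per case
assert cached_version() == "3.10"
```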