75 changes: 49 additions & 26 deletions tests/test_repository.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-
import unittest

import pytest
from py4j.protocol import Py4JError
from pyspark.sql import Row

from pydeequ.analyzers import *
from pydeequ.checks import *
from pydeequ.repository import *
from pydeequ.verification import *
from pydeequ.analyzers import AnalyzerContext, AnalysisRunner, ApproxCountDistinct
from pydeequ.checks import Check, CheckLevel
from pydeequ.repository import FileSystemMetricsRepository, InMemoryMetricsRepository, ResultKey
from pydeequ.verification import VerificationResult, VerificationSuite
from tests.conftest import setup_pyspark


@@ -18,7 +18,9 @@ def setUpClass(cls):
cls.AnalysisRunner = AnalysisRunner(cls.spark)
cls.VerificationSuite = VerificationSuite(cls.spark)
cls.sc = cls.spark.sparkContext
cls.df = cls.sc.parallelize([Row(a="foo", b=1, c=5), Row(a="bar", b=2, c=6), Row(a="baz", b=3, c=None)]).toDF()
cls.df = cls.sc.parallelize(
[Row(a="foo", b=1, c=5), Row(a="bar", b=2, c=6), Row(a="baz", b=3, c=None)]
).toDF()

@classmethod
def tearDownClass(cls):
@@ -121,12 +123,16 @@ def test_verifications_FSmetrep(self):
)

# TEST: Check JSON for tags
result_metrep_json = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
result_metrep_json = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
)

print(result_metrep_json[0]["tag"], key_tags["tag"])
self.assertEqual(result_metrep_json[0]["tag"], key_tags["tag"])

result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
@@ -146,7 +152,9 @@ def test_verifications_FSmetrep_noTags_noFile(self):
)

# TEST: Check DF parity
result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
@@ -243,12 +251,16 @@ def test_verifications_IMmetrep(self):
)

# TEST: Check JSON for tags
result_metrep_json = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
result_metrep_json = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsJson()
)

print(result_metrep_json[0]["tag"], key_tags["tag"])
self.assertEqual(result_metrep_json[0]["tag"], key_tags["tag"])

result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
@@ -267,37 +279,43 @@ def test_verifications_IMmetrep_noTags_noFile(self):
)

# TEST: Check DF parity
result_metrep = repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
result_metrep = (
repository.load().before(ResultKey.current_milli_time()).getSuccessMetricsAsDataFrame()
)

df = VerificationResult.checkResultsAsDataFrame(self.spark, result)
print(df.collect())
print(result_metrep.collect())

@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_no_useRepository(self):
"""This test should fail because it doesn't call useRepository() before saveOrAppendResult()"""
"""This run fails because it doesn't call useRepository() before saveOrAppendResult()."""
metrics_file = FileSystemMetricsRepository.helper_metrics_file(self.spark, "metrics.json")
print(f"metrics filepath: {metrics_file}")
key_tags = {"tag": "FS metrep analyzers -- FAIL"}
resultKey = ResultKey(self.spark, ResultKey.current_milli_time(), key_tags)

# MISSING useRepository()
result = (
self.AnalysisRunner.onData(self.df)
.addAnalyzer(ApproxCountDistinct("b"))
.saveOrAppendResult(resultKey)
.run()
with self.assertRaises(Py4JError) as err:
_ = (
self.AnalysisRunner.onData(self.df)
.addAnalyzer(ApproxCountDistinct("b"))
.saveOrAppendResult(resultKey)
.run()
)

self.assertIn(
"Method saveOrAppendResult([class com.amazon.deequ.repository.ResultKey]) does not exist",
str(err.exception),
)

@pytest.mark.xfail(reason="@unittest.expectedFailure")
def test_fail_no_load(self):
"""This test should fail because we do not load() for the repository reading"""
"""This run fails because we do not load() for the repository reading."""
metrics_file = FileSystemMetricsRepository.helper_metrics_file(self.spark, "metrics.json")
print(f"metrics filepath: {metrics_file}")
repository = FileSystemMetricsRepository(self.spark, metrics_file)
key_tags = {"tag": "FS metrep analyzers"}
resultKey = ResultKey(self.spark, ResultKey.current_milli_time(), key_tags)
result = (
_ = (
self.AnalysisRunner.onData(self.df)
.addAnalyzer(ApproxCountDistinct("b"))
.useRepository(repository)
@@ -306,8 +324,13 @@ def test_fail_no_load(self):
)

# MISSING: repository.load()
result_metrep_json = (
repository.before(ResultKey.current_milli_time())
.forAnalyzers([ApproxCountDistinct("b")])
.getSuccessMetricsAsJson()
with self.assertRaises(AttributeError) as err:
_ = (
repository.before(ResultKey.current_milli_time())
.forAnalyzers([ApproxCountDistinct("b")])
.getSuccessMetricsAsJson()
)

self.assertEqual(
"'FileSystemMetricsRepository' object has no attribute 'RepositoryLoader'", str(err.exception)
)
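
Note on the pattern adopted in the two failure tests above: instead of relying on @pytest.mark.xfail, they now assert the specific exception and message. A minimal, Spark-free sketch of that style, using a hypothetical DummyRepository stand-in rather than the real pydeequ classes:

import unittest


class DummyRepository:
    """Hypothetical stand-in: deliberately defines no before() helper, so calling
    it mirrors the AttributeError path exercised in test_fail_no_load above."""


class ExplicitFailureTest(unittest.TestCase):
    def test_missing_method_raises(self):
        # Assert the exact failure mode rather than marking the whole test xfail.
        with self.assertRaises(AttributeError) as err:
            DummyRepository().before(0)
        self.assertIn("has no attribute 'before'", str(err.exception))


if __name__ == "__main__":
    unittest.main()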