Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions sdv/cag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sdv.cag.one_hot_encoding import OneHotEncoding
from sdv.cag.programmable_constraint import (
ProgrammableConstraint,
SingleTableProgrammableConstraint,
)

__all__ = (
Expand All @@ -17,5 +16,4 @@
'Range',
'OneHotEncoding',
'ProgrammableConstraint',
'SingleTableProgrammableConstraint',
)
55 changes: 4 additions & 51 deletions sdv/cag/programmable_constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@

from copy import deepcopy

from sdv.cag._errors import ConstraintNotMetError
from sdv.cag.base import BaseConstraint


class ProgrammableConstraint:
"""Base class to create programmable constraints."""

_is_single_table = False
_is_single_table = True

def validate(self, metadata):
"""Validates that the metadata is compatible with the constraint and its parameters.
Expand Down Expand Up @@ -125,12 +124,6 @@ def fix_data(self, synthetic_data):
return synthetic_data


class SingleTableProgrammableConstraint(ProgrammableConstraint):
"""Single table variant of the base programmable constraint class."""

_is_single_table = True


class ProgrammableConstraintHarness(BaseConstraint):
"""Constraint interface class for programmable constraints."""

Expand All @@ -141,63 +134,23 @@ def __init__(self, programmable_constraint):
self._is_single_table = self.programmable_constraint._is_single_table

def _validate_constraint_with_metadata(self, metadata):
if self.programmable_constraint._is_single_table and len(metadata.tables) != 1:
if getattr(self.programmable_constraint, 'table_name', None) is None:
raise ConstraintNotMetError(
'SingleTableProgrammableConstraint cannot be used with multi-table metadata '
'if the `table_name` attribute has not been set. Please set the table name '
'attribute to the target table, or use the ProgrammableContraint '
'base class instead.'
)

self.programmable_constraint.validate(metadata)

def _validate_constraint_with_data(self, data, metadata):
if self._is_single_table:
data = data[self._get_single_table_name(metadata)]

self.programmable_constraint.validate_input_data(data)

def _get_updated_metadata(self, metadata):
metadata = deepcopy(metadata)
return self.programmable_constraint.get_updated_metadata(metadata)

def _fit(self, data, metadata):
if self._is_single_table:
data = data[self._table_name]

self.programmable_constraint.fit(data, metadata)

def _transform(self, data):
if self._is_single_table:
data = data[self._table_name]

transformed = self.programmable_constraint.transform(data)

if self._is_single_table:
return {self._table_name: transformed}

return transformed
return self.programmable_constraint.transform(data)

def _reverse_transform(self, data):
if self._is_single_table:
data = data[self._table_name]

reverse_transformed = self.programmable_constraint.reverse_transform(data)

if self._is_single_table:
return {self._table_name: reverse_transformed}

return reverse_transformed
return self.programmable_constraint.reverse_transform(data)

def _is_valid(self, data, metadata):
if self._is_single_table:
table_name = self._get_single_table_name(metadata)
data = data[table_name]

is_valid = self.programmable_constraint.is_valid(data)

if self._is_single_table:
return {table_name: is_valid}

return is_valid
return self.programmable_constraint.is_valid(data)
4 changes: 2 additions & 2 deletions sdv/sequential/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,13 @@ def get_parameters(self):
def add_constraints(self, constraints):
"""Add constraints to the synthesizer.

For PARSynthesizers allow SingleTableProgrammableConstraints and built-in constraints
For PARSynthesizers allow ProgrammableConstraints and built-in constraints
that follow these rules:

1) All constraints must be either for all contextual columns or for all non-contextual columns.
No mixing constraints that cover both contextual and non-contextual columns
2) No overlapping constraints (there are no constraints that act on the same column)
3) Any custom constraint is allowed, as long as it is a SingleTableProgrammableConstraint
3) Any custom constraint is allowed, as long as it is a single table constraint

Args:
constraints (list):
Expand Down
112 changes: 38 additions & 74 deletions tests/integration/cag/test_programmable_constraint.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,70 @@
"""Programmable Constraint Integration Tests."""

import pandas as pd
import pytest

from sdv.cag import FixedCombinations, ProgrammableConstraint, SingleTableProgrammableConstraint
from sdv.cag import FixedCombinations, ProgrammableConstraint
from sdv.datasets.demo import download_demo
from sdv.metadata import Metadata
from sdv.multi_table import HMASynthesizer
from sdv.single_table import GaussianCopulaSynthesizer


class SingleTableIfTrueThenZero(SingleTableProgrammableConstraint):
class SingleTableIfTrueThenZero(ProgrammableConstraint):
"""Custom constraint that ensures the target column is 0 if a flag column is True."""

def __init__(self, target_column, flag_column):
def __init__(self, target_column, flag_column, table_name=None):
self.target_column = target_column
self.flag_column = flag_column
self.table_name = table_name

def validate(self, metadata):
table_name = metadata._get_single_table_name()
assert metadata.tables[table_name].columns[self.target_column]['sdtype'] == 'numerical'
assert metadata.tables[table_name].columns[self.flag_column]['sdtype'] == 'boolean'
assert metadata.tables[self.table_name].columns[self.target_column]['sdtype'] == 'numerical'
assert metadata.tables[self.table_name].columns[self.flag_column]['sdtype'] == 'boolean'

def validate_input_data(self, data):
return

def transform(self, data):
"""Transform the data if amenities fee is to be applied."""
typical_value = data[self.target_column].median()
data[self.target_column] = data[self.target_column].mask(
data[self.flag_column], typical_value
typical_value = data[self.table_name][self.target_column].median()
data[self.table_name][self.target_column] = data[self.table_name][self.target_column].mask(
data[self.table_name][self.flag_column], typical_value
)

return data

def reverse_transform(self, transformed_data):
"""Reverse the data if amenities fee is to be applied."""
transformed_data[self.target_column] = transformed_data[self.target_column].mask(
transformed_data[self.flag_column], 0.0
transformed_table = transformed_data[self.table_name]
transformed_table[self.target_column] = transformed_table[self.target_column].mask(
transformed_table[self.flag_column], 0.0
)

transformed_data[self.table_name] = transformed_table
return transformed_data

def get_updated_metadata(self, metadata):
return metadata

def is_valid(self, synthetic_data):
"""Validate that if ``has_rewards`` is True, the amenities fee is 0."""
true_values = (synthetic_data[self.flag_column]) & (
synthetic_data[self.target_column] == 0.0
is_valid = {
table: pd.Series(True, index=synthetic_data[table].index) for table in synthetic_data
}

true_values = (synthetic_data[self.table_name][self.flag_column]) & (
synthetic_data[self.table_name][self.target_column] == 0.0
)
false_values = ~synthetic_data[self.flag_column]
return (true_values) | (false_values)
false_values = ~synthetic_data[self.table_name][self.flag_column]
is_valid[self.table_name] = (true_values) | (false_values)
return is_valid


@pytest.fixture
def programmable_constraint():
class MyConstraint(ProgrammableConstraint):
def __init__(self, column_names, table_name):
def __init__(self, column_names, table_name=None):
self.column_names = column_names
self.table_name = table_name
self._joint_column = '#'.join(self.column_names)
Expand All @@ -75,6 +83,7 @@ def validate_input_data(self, data):
def fit(self, data, metadata):
self.metadata = metadata
FixedCombinations._fit(self, data, metadata)
self._fitted = True

def transform(self, data):
return FixedCombinations._transform(self, data)
Expand All @@ -86,60 +95,13 @@ def reverse_transform(self, transformed_data):
return FixedCombinations._reverse_transform(self, transformed_data)

def is_valid(self, synthetic_data):
return FixedCombinations._is_valid(self, synthetic_data)
return FixedCombinations._is_valid(self, synthetic_data, self.metadata)

return MyConstraint


@pytest.fixture
def single_table_programmable_constraint():
class MyConstraint(SingleTableProgrammableConstraint):
def __init__(self, column_names, table_name):
self.column_names = column_names
self.table_name = table_name
self._joint_column = '#'.join(self.column_names)
self._combinations = None
self._fitted = False

def _get_single_table_name(self, metadata):
# Have to define this so that we can re-use existing methods on the constraint
return self.table_name

def validate(self, metadata):
FixedCombinations._validate_constraint_with_metadata(self, metadata)

def validate_input_data(self, data):
return

def fit(self, data, metadata):
self.metadata = metadata
data = {self.table_name: data}
FixedCombinations._fit(self, data, metadata)
self._fitted = True

def transform(self, data):
data = {self.table_name: data}
transformed = FixedCombinations._transform(self, data)
return transformed[self.table_name]

def get_updated_metadata(self, metadata):
return FixedCombinations._get_updated_metadata(self, metadata)

def reverse_transform(self, transformed_data):
transformed_data = {self.table_name: transformed_data}
reverse_transformed = FixedCombinations._reverse_transform(self, transformed_data)
return reverse_transformed[self.table_name]

def is_valid(self, synthetic_data):
synthetic_data = {self.table_name: synthetic_data}
is_valid = FixedCombinations._is_valid(self, synthetic_data, self.metadata)
return is_valid[self.table_name]

return MyConstraint


def test_end_to_end_programmable_constraint(programmable_constraint):
"""Test using a programmable constraint with a synthesizer end-to-end."""
def test_end_to_end_programmable_constraint_multi_table(programmable_constraint):
"""Test using a programmable constraint with a multi table synthesizer end-to-end."""
data, metadata = download_demo('multi_table', 'fake_hotels')
my_constraint = programmable_constraint(
column_names=['has_rewards', 'room_type'], table_name='guests'
Expand All @@ -161,10 +123,10 @@ def test_end_to_end_programmable_constraint(programmable_constraint):
assert isinstance(constraints[0], programmable_constraint)


def test_end_to_end_single_table_programmable_constraint(single_table_programmable_constraint):
"""Test using a single table programmable constraint with a synthesizer end-to-end."""
def test_end_to_end_programmable_constraint_single_table(programmable_constraint):
"""Test using a programmable constraint with a single table synthesizer end-to-end."""
data, metadata = download_demo('single_table', 'fake_hotel_guests')
my_constraint = single_table_programmable_constraint(
my_constraint = programmable_constraint(
column_names=['has_rewards', 'room_type'], table_name='fake_hotel_guests'
)
synthesizer = GaussianCopulaSynthesizer(metadata)
Expand All @@ -178,7 +140,7 @@ def test_end_to_end_single_table_programmable_constraint(single_table_programmab
# Assert
original_combinations = set(zip(data['has_rewards'], data['room_type']))
assert set(zip(sampled_data['has_rewards'], sampled_data['room_type'])) == original_combinations
assert isinstance(constraints[0], single_table_programmable_constraint)
assert isinstance(constraints[0], programmable_constraint)


def test_end_to_end_simple_constraint_with_no_fit(programmable_constraint):
Expand All @@ -187,7 +149,9 @@ def test_end_to_end_simple_constraint_with_no_fit(programmable_constraint):
data, metadata = download_demo('single_table', 'fake_hotel_guests')
synthesizer = GaussianCopulaSynthesizer(metadata)
custom_constraint = SingleTableIfTrueThenZero(
target_column='amenities_fee', flag_column='has_rewards'
target_column='amenities_fee',
flag_column='has_rewards',
table_name=metadata._get_single_table_name(),
)

# Run
Expand All @@ -201,7 +165,7 @@ def test_end_to_end_simple_constraint_with_no_fit(programmable_constraint):
assert all((true_values) | (false_values))


def test_get_updated_metadata_is_passed_metadata_copy(single_table_programmable_constraint):
def test_get_updated_metadata_is_passed_metadata_copy(programmable_constraint):
"""Test that the ``get_updated_metadata`` is not given the original metadata."""
# Setup
data, metadata = download_demo('single_table', 'fake_hotel_guests')
Expand All @@ -214,8 +178,8 @@ def get_updated_metadata(self, metadata):
metadata.add_column(self._joint_column, sdtype='categorical')
return metadata

single_table_programmable_constraint.get_updated_metadata = get_updated_metadata
my_constraint = single_table_programmable_constraint(
programmable_constraint.get_updated_metadata = get_updated_metadata
my_constraint = programmable_constraint(
column_names=['has_rewards', 'room_type'], table_name='fake_hotel_guests'
)

Expand Down
Loading
Loading