Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ctgan/synthesizers/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ def _set_device(enable_gpu, device=None):
def validate_and_set_device(enable_gpu, cuda):
enable_gpu = get_enable_gpu_value(enable_gpu, cuda)
return _set_device(enable_gpu)


def _format_score(score):
"""Format a score as a fixed-length string ``±XX.XX``.

Values are clipped to the range ``[-99.99, +99.99]`` so the result
is always exactly 6 characters.
"""
score = max(-99.99, min(99.99, score))
sign = '+' if score >= 0 else '-'
return f'{sign}{abs(score):05.2f}'
13 changes: 9 additions & 4 deletions ctgan/synthesizers/ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ctgan.data_sampler import DataSampler
from ctgan.data_transformer import DataTransformer
from ctgan.errors import InvalidDataError
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device
from ctgan.synthesizers.base import BaseSynthesizer, random_state


Expand Down Expand Up @@ -379,8 +379,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):

epoch_iterator = tqdm(range(epochs), disable=(not self._verbose))
if self._verbose:
description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
epoch_iterator.set_description(description.format(gen=0, dis=0))
description = 'Gen. ({gen}) | Discrim. ({dis})'
epoch_iterator.set_description(
description.format(gen=_format_score(0), dis=_format_score(0))
)

steps_per_epoch = max(len(train_data) // self._batch_size, 1)
for i in epoch_iterator:
Expand Down Expand Up @@ -479,7 +481,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):

if self._verbose:
epoch_iterator.set_description(
description.format(gen=generator_loss, dis=discriminator_loss)
description.format(
gen=_format_score(generator_loss),
dis=_format_score(discriminator_loss),
)
)

@random_state
Expand Down
8 changes: 4 additions & 4 deletions ctgan/synthesizers/tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tqdm import tqdm

from ctgan.data_transformer import DataTransformer
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device
from ctgan.synthesizers.base import BaseSynthesizer, random_state


Expand Down Expand Up @@ -161,8 +161,8 @@ def fit(self, train_data, discrete_columns=()):
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
if self.verbose:
iterator_description = 'Loss: {loss:.3f}'
iterator.set_description(iterator_description.format(loss=0))
iterator_description = 'Loss: {loss}'
iterator.set_description(iterator_description.format(loss=_format_score(0)))

for i in iterator:
loss_values = []
Expand Down Expand Up @@ -205,7 +205,7 @@ def fit(self, train_data, discrete_columns=()):

if self.verbose:
iterator.set_description(
iterator_description.format(loss=loss.detach().cpu().item())
iterator_description.format(loss=_format_score(loss.detach().cpu().item()))
)

@random_state
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/synthesizer/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,5 @@ def test_tvae_save(tmpdir, capsys):
assert len(loss_values) == 10
assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'}
assert all(loss_values['Batch'] == 0)
last_loss_val = loss_values['Loss'].iloc[-1]
assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out
last_loss_val = max(-99.99, min(99.99, loss_values['Loss'].iloc[-1]))
assert f'Loss: {last_loss_val:+06.2f}: 100%' in captured_out
28 changes: 27 additions & 1 deletion tests/unit/synthesizer/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
import pytest
import torch

from ctgan.synthesizers._utils import _set_device, get_enable_gpu_value, validate_and_set_device
from ctgan.synthesizers._utils import (
_format_score,
_set_device,
get_enable_gpu_value,
validate_and_set_device,
)


def test__validate_gpu_parameter():
Expand Down Expand Up @@ -61,6 +66,27 @@ def test__set_device():
assert device_4 == torch.device('cpu')


@pytest.mark.parametrize(
'score, expected',
[
(0, '+00.00'),
(1.233434, '+01.23'),
(-0.93, '-00.93'),
(0.01, '+00.01'),
(-1.21, '-01.21'),
(99.99, '+99.99'),
(-99.99, '-99.99'),
(150, '+99.99'),
(-200, '-99.99'),
],
)
def test__format_score(score, expected):
"""Test the ``_format_score`` method."""
result = _format_score(score)
assert result == expected
assert len(result) == 6


@patch('ctgan.synthesizers._utils._set_device')
@patch('ctgan.synthesizers._utils.get_enable_gpu_value')
def test_validate_and_set_device(mock_validate, mock_set_device):
Expand Down
15 changes: 15 additions & 0 deletions tests/unit/synthesizer/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,21 @@ def test__cond_loss(self):

assert (result - expected).abs() < 1e-3

@patch('ctgan.synthesizers.ctgan._format_score')
def test_fit_verbose_calls_format_score(self, format_score_mock):
"""Test that ``_format_score`` is called during verbose fitting."""
# Setup
format_score_mock.side_effect = lambda x: f'+{abs(x):05.2f}'
data = pd.DataFrame({'col1': [0, 1, 2, 3, 4], 'col2': ['a', 'b', 'c', 'a', 'b']})

# Run
ctgan = CTGAN(epochs=1, verbose=True)
ctgan.fit(data, discrete_columns=['col2'])

# Assert
assert format_score_mock.call_count == 4
format_score_mock.assert_any_call(0)

def test__validate_discrete_columns(self):
"""Test `_validate_discrete_columns` if the discrete column doesn't exist.

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/synthesizer/test_tvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ def mock_add(a, b):

# Assert
tqdm_mock.assert_called_once_with(range(epochs), disable=False)
assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000')
assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235')
assert iterator_mock.set_description.call_args_list[0] == call('Loss: +00.00')
assert iterator_mock.set_description.call_args_list[1] == call('Loss: +01.23')
assert iterator_mock.set_description.call_count == 2