257 changes: 226 additions & 31 deletions ctgan/synthesizers/ctgan.py
@@ -146,6 +146,28 @@ class CTGAN(BaseSynthesizer):
**Deprecated** Whether to attempt to use cuda for GPU computation.
If this is False or CUDA is not available, CPU will be used.
Defaults to ``True``.
adaptive_training (bool):
Whether to use adaptive discriminator-generator step balancing.
When enabled, ``discriminator_steps`` is adjusted each epoch based on
the ratio between the generator and discriminator losses. Defaults to ``False``.
gradient_clipping (float):
Maximum gradient norm for gradient clipping. If None, no clipping is applied.
Defaults to ``None``.
early_stopping (bool):
Whether to enable early stopping based on loss convergence.
Defaults to ``False``.
early_stopping_patience (int):
Number of epochs to wait before early stopping if no improvement.
Defaults to 10.
adaptive_lr (bool):
Whether to use adaptive learning rate scheduling based on loss plateaus.
Defaults to ``False``.
lr_patience (int):
Number of epochs to wait before reducing learning rate.
Defaults to 5.
lr_factor (float):
Factor by which learning rate is reduced.
Defaults to 0.5.
"""

def __init__(
@@ -165,6 +187,13 @@ def __init__(
pac=10,
enable_gpu=True,
cuda=None,
adaptive_training=False,
gradient_clipping=None,
early_stopping=False,
early_stopping_patience=10,
adaptive_lr=False,
lr_patience=5,
lr_factor=0.5,
):
assert batch_size % 2 == 0

@@ -179,6 +208,7 @@ def __init__(

self._batch_size = batch_size
self._discriminator_steps = discriminator_steps
self._base_discriminator_steps = discriminator_steps
self._log_frequency = log_frequency
self._verbose = verbose
self._epochs = epochs
@@ -190,6 +220,23 @@ def __init__(
self._generator = None
self.loss_values = None

# Adaptive training parameters
self._adaptive_training = adaptive_training
self._gradient_clipping = gradient_clipping
self._early_stopping = early_stopping
self._early_stopping_patience = early_stopping_patience
self._adaptive_lr = adaptive_lr
self._lr_patience = lr_patience
self._lr_factor = lr_factor

# Training state tracking
self._best_loss = float('inf')
self._patience_counter = 0
self._lr_patience_counter = 0
self._generator_grad_norms = []
self._discriminator_grad_norms = []
self._loss_history = []

@staticmethod
def _gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
"""Deals with the instability of the gumbel_softmax for older versions of torch.
@@ -312,6 +359,108 @@ def _validate_null_data(self, train_data, discrete_columns):
'Please remove all null values from your continuous training data.'
)

def _compute_gradient_norm(self, model):
"""Compute the gradient norm of a model.

Args:
model (torch.nn.Module): The model to compute gradients for.

Returns:
float: The total L2 norm of the model's parameter gradients.
"""
total_norm = 0.0
for param in model.parameters():
if param.grad is not None:
param_norm = param.grad.data.norm(2)
total_norm += param_norm.item() ** 2
return total_norm ** (1.0 / 2)

def _clip_gradients(self, model):
"""Clip gradients of a model if gradient_clipping is enabled.

Args:
model (torch.nn.Module): The model to clip gradients for.
"""
if self._gradient_clipping is not None:
torch.nn.utils.clip_grad_norm_(model.parameters(), self._gradient_clipping)
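A side note on the two helpers above: `torch.nn.utils.clip_grad_norm_` already returns the total pre-clipping gradient norm, so when clipping is active the norm could be obtained without a second pass over the parameters. A sketch of that alternative (`clip_and_measure` is a hypothetical helper, not what this patch does):

# --- Alternative sketch (illustrative, not part of ctgan.py) ---
import torch

def clip_and_measure(model, max_norm):
    """Clip gradients when a threshold is set and return the total L2 norm."""
    if max_norm is not None:
        # clip_grad_norm_ returns the total norm computed before clipping
        return torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm).item()
    total = sum(p.grad.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None)
    return total ** 0.5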

def _adapt_discriminator_steps(self, gen_loss, disc_loss):
"""Adaptively adjust discriminator steps based on loss balance.

Args:
gen_loss (float): Current generator loss.
disc_loss (float): Current discriminator loss.

Returns:
int: Adjusted discriminator steps.
"""
if not self._adaptive_training:
return self._discriminator_steps

# Compute loss ratio
loss_ratio = abs(gen_loss) / (abs(disc_loss) + 1e-8)

# If generator is too strong (low loss), increase discriminator steps
# If discriminator is too strong (high gen loss), decrease discriminator steps
if loss_ratio < 0.5:
# Generator too strong, need more discriminator training
new_steps = min(self._base_discriminator_steps + 1, 5)
elif loss_ratio > 2.0:
# Discriminator too strong, reduce discriminator training
new_steps = max(self._base_discriminator_steps - 1, 1)
else:
new_steps = self._base_discriminator_steps

self._discriminator_steps = new_steps
return self._discriminator_steps
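To make the thresholds above concrete, here is the same ratio heuristic in isolation (standalone sketch with arbitrary example values, not part of this patch):

# --- Standalone sketch of the ratio heuristic (illustrative only) ---
def adapt_steps(gen_loss, disc_loss, base_steps=1):
    ratio = abs(gen_loss) / (abs(disc_loss) + 1e-8)
    if ratio < 0.5:
        return min(base_steps + 1, 5)   # generator too strong: train discriminator more
    if ratio > 2.0:
        return max(base_steps - 1, 1)   # discriminator too strong: train it less
    return base_steps

# adapt_steps(-0.2, 1.0) -> 2 (ratio 0.2); adapt_steps(3.0, 1.0) -> 1 (ratio 3.0)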

def _check_early_stopping(self, current_loss):
"""Check if early stopping criteria is met.

Args:
current_loss (float): Current epoch loss.

Returns:
bool: True if training should stop, False otherwise.
"""
if not self._early_stopping:
return False

if current_loss < self._best_loss:
self._best_loss = current_loss
self._patience_counter = 0
return False
else:
self._patience_counter += 1
if self._patience_counter >= self._early_stopping_patience:
return True
return False
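For illustration, the same patience rule applied to a toy loss sequence (standalone sketch, not part of this patch):

# --- Standalone sketch of the patience rule (illustrative only) ---
def first_stop_epoch(losses, patience=3):
    best, waited = float('inf'), 0
    for epoch, loss in enumerate(losses):
        if loss < best:
            best, waited = loss, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None

# first_stop_epoch([1.0, 0.8, 0.9, 0.85, 0.81], patience=3) -> 4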

def _adapt_learning_rate(self, optimizers, current_loss):
    """Adaptively adjust learning rates when the loss plateaus.

    Args:
        optimizers (list of torch.optim.Optimizer): The optimizers to adjust.
        current_loss (float): Current epoch loss.
    """
    if not self._adaptive_lr:
        return

    # Compare against previous epochs only; the current loss has not yet been
    # appended to the history when this method is called.
    if len(self._loss_history) > 0:
        if current_loss >= min(self._loss_history[-self._lr_patience:]):
            self._lr_patience_counter += 1
        else:
            self._lr_patience_counter = 0

    # Check the plateau once per epoch and reduce both learning rates together,
    # so the patience counter is not incremented twice per epoch.
    if self._lr_patience_counter >= self._lr_patience:
        for optimizer in optimizers:
            for param_group in optimizer.param_groups:
                old_lr = param_group['lr']
                new_lr = old_lr * self._lr_factor
                param_group['lr'] = new_lr
                if self._verbose:
                    print(f'Reducing learning rate from {old_lr:.6f} to {new_lr:.6f}')
        self._lr_patience_counter = 0
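PyTorch's built-in `torch.optim.lr_scheduler.ReduceLROnPlateau` implements very similar plateau handling; the sketch below shows the equivalent wiring for comparison only (the parameters and optimizer are stand-ins, and this patch keeps its own counter rather than using the scheduler):

# --- Comparison sketch using the built-in scheduler (illustrative only) ---
import torch

params = [torch.nn.Parameter(torch.zeros(1))]        # stand-in parameters
optimizer = torch.optim.Adam(params, lr=2e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5
)

for epoch in range(20):
    combined_loss = 1.0                               # placeholder epoch loss
    scheduler.step(combined_loss)                     # halves the lr after 5 stale epochs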

@random_state
def fit(self, train_data, discrete_columns=(), epochs=None):
"""Fit the CTGAN Synthesizer models to the training data.
@@ -339,6 +488,15 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
DeprecationWarning,
)

# Reset training state
self._best_loss = float('inf')
self._patience_counter = 0
self._lr_patience_counter = 0
self._generator_grad_norms = []
self._discriminator_grad_norms = []
self._loss_history = []
self._discriminator_steps = self._base_discriminator_steps

self._transformer = DataTransformer()
self._transformer.fit(train_data, discrete_columns)

@@ -386,6 +544,12 @@ def fit(self, train_data, discrete_columns=(), epochs=None):

steps_per_epoch = max(len(train_data) // self._batch_size, 1)
for i in epoch_iterator:
# Adapt discriminator steps at the start of each epoch
if i > 0 and self._adaptive_training:
prev_gen_loss = self.loss_values.iloc[-1]['Generator Loss'] if not self.loss_values.empty else 0
prev_disc_loss = self.loss_values.iloc[-1]['Discriminator Loss'] if not self.loss_values.empty else 0
self._adapt_discriminator_steps(prev_gen_loss, prev_disc_loss)

for id_ in range(steps_per_epoch):
for n in range(self._discriminator_steps):
fakez = torch.normal(mean=mean, std=std)
@@ -432,6 +596,9 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
optimizerD.zero_grad(set_to_none=False)
pen.backward(retain_graph=True)
loss_d.backward()
self._clip_gradients(discriminator)
disc_grad_norm = self._compute_gradient_norm(discriminator)
self._discriminator_grad_norms.append(disc_grad_norm)
optimizerD.step()

fakez = torch.normal(mean=mean, std=std)
@@ -462,10 +629,14 @@ def fit(self, train_data, discrete_columns=(), epochs=None):

optimizerG.zero_grad(set_to_none=False)
loss_g.backward()
self._clip_gradients(self._generator)
gen_grad_norm = self._compute_gradient_norm(self._generator)
self._generator_grad_norms.append(gen_grad_norm)
optimizerG.step()

generator_loss = loss_g.detach().cpu().item()
discriminator_loss = loss_d.detach().cpu().item()
combined_loss = abs(generator_loss) + abs(discriminator_loss)

epoch_loss_df = pd.DataFrame({
'Epoch': [i],
Expand All @@ -479,6 +650,19 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
else:
self.loss_values = epoch_loss_df

# Adaptive learning rate: check for a plateau against previous epochs first,
# then record the current loss in the history
self._adapt_learning_rate([optimizerG, optimizerD], combined_loss)
self._loss_history.append(combined_loss)

# Early stopping check
if self._check_early_stopping(combined_loss):
if self._verbose:
print(f'Early stopping triggered at epoch {i}')
break

if self._verbose:
epoch_iterator.set_description(
description.format(
@@ -506,43 +690,54 @@ def sample(self, n, condition_column=None, condition_value=None):
Returns:
numpy.ndarray or pandas.DataFrame
"""
if condition_column is not None and condition_value is not None:
condition_info = self._transformer.convert_column_name_value_to_id(
condition_column, condition_value
)
global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
condition_info, self._batch_size
)
else:
global_condition_vec = None

steps = n // self._batch_size + 1
data = []
for i in range(steps):
mean = torch.zeros(self._batch_size, self._embedding_dim)
std = mean + 1
fakez = torch.normal(mean=mean, std=std).to(self._device)
# Set generator to eval mode for consistent sampling behavior
was_training = self._generator.training if self._generator is not None else False
if self._generator is not None:
self._generator.eval()

if global_condition_vec is not None:
condvec = global_condition_vec.copy()
try:
if condition_column is not None and condition_value is not None:
condition_info = self._transformer.convert_column_name_value_to_id(
condition_column, condition_value
)
global_condition_vec = self._data_sampler.generate_cond_from_condition_column_info(
condition_info, self._batch_size
)
else:
condvec = self._data_sampler.sample_original_condvec(self._batch_size)
global_condition_vec = None

steps = n // self._batch_size + 1
data = []
with torch.no_grad():
for i in range(steps):
mean = torch.zeros(self._batch_size, self._embedding_dim)
std = mean + 1
fakez = torch.normal(mean=mean, std=std).to(self._device)

if global_condition_vec is not None:
condvec = global_condition_vec.copy()
else:
condvec = self._data_sampler.sample_original_condvec(self._batch_size)

if condvec is None:
pass
else:
c1 = condvec
c1 = torch.from_numpy(c1).to(self._device)
fakez = torch.cat([fakez, c1], dim=1)
if condvec is None:
pass
else:
c1 = condvec
c1 = torch.from_numpy(c1).to(self._device)
fakez = torch.cat([fakez, c1], dim=1)

fake = self._generator(fakez)
fakeact = self._apply_activate(fake)
data.append(fakeact.detach().cpu().numpy())
fake = self._generator(fakez)
fakeact = self._apply_activate(fake)
data.append(fakeact.detach().cpu().numpy())

data = np.concatenate(data, axis=0)
data = data[:n]
data = np.concatenate(data, axis=0)
data = data[:n]

return self._transformer.inverse_transform(data)
return self._transformer.inverse_transform(data)
finally:
# Restore generator training mode
if self._generator is not None and was_training:
self._generator.train()

def set_device(self, device):
"""Set the `device` to be used ('GPU' or 'CPU)."""