Conversation

@sarakodeiri (Collaborator)

PR Type

[Feature | Fix]

Short Description

Clickup Ticket(s): https://app.clickup.com/t/868g60gxr

Refactored classification pipeline and classifier classes.
Added final classifier training (somehow missed it in the previous PR, sorry).
Added basic inference to get predictions on challenge data (no comparison with ground truth yet).

Tests Added

Modified tests to sync with the refactoring.
No new tests. Will add some when the inference evaluation is complete.

coderabbitai bot commented Jan 24, 2026

📝 Walkthrough

This pull request introduces an end-to-end inference pipeline for the EPT attack framework. Changes include adding configuration flags for inference execution and results storage, implementing helper functions for training summarization and best attack classifier persistence, and extending the MLPClassifier with training and prediction methods. A new run_inference stage loads trained classifiers, processes inference data, and saves predictions. Tests are refactored from mock-based to fixture-driven approaches to validate the updated classifier behavior.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 76.92%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check (❓ Inconclusive): The title 'Sk/ept classifier' is vague and uses a branch-name prefix; it doesn't clearly convey the main change. Resolution: revise the title to be more descriptive, e.g., 'Add classifier training and inference pipeline for EPT attack'.

✅ Passed checks (1 passed)
  • Description check: The description covers all required template sections (PR Type, Short Description with ticket link, and Tests Added) with sufficient detail about refactoring, new features, and test modifications.


coderabbitai bot left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/ept_attack/run_ept_attack.py (2)

108-140: Critical: _summarize_and_save_training_results is defined twice.

The function _summarize_and_save_training_results is defined at lines 108-140 and again at lines 275-307. This is a merge/rebase artifact. The second definition (lines 275-307) will shadow the first. Remove one of the duplicate definitions.

Also applies to: 275-307


144-273: Critical: run_attack_classifier_training is defined twice.

The function run_attack_classifier_training is defined at lines 144-273 and again at lines 362-497. The static analysis tool also flags this (F811). The second definition appears to be the intended version as it includes the call to _train_and_save_best_attack_classifier. Remove the first definition (lines 144-273).

Also applies to: 362-497

🤖 Fix all issues with AI agents
In `@examples/ept_attack/run_ept_attack.py`:
- Around line 531-535: The inference currently concatenates CSVs into df_inference_features and calls trained_model.predict directly, but training used filter_data with best_column_types in train_attack_classifier. Update inference to apply the same filtering: ensure best_column_types is persisted with the model or passed into run_inference, then call filter_data(df_inference_features, best_column_types) before calling trained_model.predict (see the sketch after this list).
- Around line 583-586: In main(), there's a duplicated conditional block that
re-checks config.pipeline.run_inference and
config.pipeline.run_attack_classifier_training; remove the second occurrence so
each pipeline check is only executed once—locate the duplicate if
config.pipeline.run_inference: ... and if
config.pipeline.run_attack_classifier_training: ... (referencing
run_inference(config) and run_attack_classifier_training(config)) and delete
that repeated block, leaving the original checks intact.
- Around line 551-552: The comment contains a typo: change "challenege" to
"challenge" in the TODO line referencing evaluation of inference results; update
the TODO comment text that mentions evaluating inference results using the
challenge labels and ensure the referenced function call
_evaluate_inference_results(predictions, diffusion_model_name) remains correctly
spelled and uncommented/implemented when you add the evaluation logic.
- Around line 11-15: Remove duplicate imports: keep a single import for pickle,
one for "from collections import defaultdict", and one for "from datetime import
datetime"; edit the top of run_ept_attack.py to eliminate the repeated lines
referencing pickle, defaultdict, and datetime so only one import statement per
symbol remains and ensure no other code depends on the removed duplicates.
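
A minimal sketch of the first fix above (applying the same filter_data filtering at inference). The dict-bundle format is an assumption, not the repository's actual persistence layout, and the surrounding names (model_path, trained_model, best_column_types, df_inference_features, filter_data) come from the comment and are assumed to be in scope.

import pickle

# Training time (hypothetical layout): persist the column types alongside the classifier.
with open(model_path, "wb") as file:
    pickle.dump({"model": trained_model, "column_types": best_column_types}, file)

# Inference time: load both, then apply the same filtering before predicting.
with open(model_path, "rb") as file:
    bundle = pickle.load(file)

filtered_features = filter_data(df_inference_features, bundle["column_types"])
predictions = bundle["model"].predict(filtered_features)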

In `@tests/unit/attacks/ept_attack/test_classification.py`:
- Around line 238-241: The model fixture currently hardcodes input_dim=20 and hidden_dim=10 in its signature instead of using the existing input_dim and hidden_dim fixtures. Update the fixture declaration to accept input_dim and hidden_dim as parameters (removing the hardcoded defaults) and construct the MLPClassifier from the injected fixture values (e.g., MLPClassifier(input_size=input_dim, hidden_size=hidden_dim, ...)) so changes to the input_dim/hidden_dim fixtures propagate to tests that depend on model (see the sketch below).
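
A sketch of that fixture wiring, assuming pytest and the constructor keywords shown in the comment. The import path mirrors src/midst_toolkit/attacks/ept/classification.py, and the input_dim/hidden_dim fixtures already exist in the test module; they are repeated here only to keep the sketch self-contained.

import pytest

from midst_toolkit.attacks.ept.classification import MLPClassifier


@pytest.fixture
def input_dim() -> int:
    return 20


@pytest.fixture
def hidden_dim() -> int:
    return 10


@pytest.fixture
def model(input_dim: int, hidden_dim: int) -> MLPClassifier:
    # Injecting the fixtures means changes to input_dim/hidden_dim propagate
    # to every test that depends on `model`.
    return MLPClassifier(input_size=input_dim, hidden_size=hidden_dim)
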
🧹 Nitpick comments (5)
examples/ept_attack/config.yaml (1)

12-13: Minor formatting: Add blank line before the pipeline control section.

There's a missing blank line between the new inference_results_path and the # Pipeline control comment, which would improve readability and maintain consistency with the rest of the file.

Suggested fix
   inference_results_path: ${data_paths.output_data_path}/inference_results # Path to save inference results
+
 # Pipeline control
src/midst_toolkit/attacks/ept/classification.py (1)

122-142: Consider mini-batch training for scalability.

The current implementation passes all training data through the network as a single batch per epoch. For large datasets, this could exhaust GPU memory. Consider adding a batch_size parameter with mini-batch iteration for better scalability, or document the expected dataset size limitations.
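
A rough sketch of what that mini-batch loop could look like. The attribute names (self.network, self.optimizer, self.criterion, self.device) are assumptions; the actual MLPClassifier internals are not shown in this diff.

import torch
from torch.utils.data import DataLoader, TensorDataset

def train(self, features: torch.Tensor, labels: torch.Tensor, epochs: int = 100, batch_size: int = 256) -> None:
    # Iterate in mini-batches instead of pushing the full dataset through the
    # network once per epoch, keeping peak GPU memory bounded by batch_size.
    loader = DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)
    self.network.train()
    for _ in range(epochs):
        for batch_features, batch_labels in loader:
            batch_features = batch_features.to(self.device)
            batch_labels = batch_labels.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.network(batch_features), batch_labels)
            loss.backward()
            self.optimizer.step()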

tests/unit/attacks/ept_attack/test_classification.py (1)

254-260: Consider consolidating with existing test_mlp_classifier.

This test duplicates much of the layer dimension checking already done in test_mlp_classifier (lines 111-145). Consider either consolidating these tests or differentiating their purposes more clearly (e.g., this one could focus on the new epochs/device parameters).

examples/ept_attack/run_ept_attack.py (2)

522-523: Note: Security consideration with pickle deserialization.

Using pickle.load() on files can be a security risk if the file source is untrusted (arbitrary code execution). In this context, the model files are generated by the same pipeline, so the risk is mitigated. However, consider documenting this assumption or using a safer serialization format (e.g., torch.save/torch.load for PyTorch models, or joblib with explicit trust flags) for defense in depth.
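
A sketch of the torch.save/torch.load alternative mentioned above, assuming the classifier wraps a torch.nn.Module exposed as a (hypothetical) network attribute; input_dim and hidden_dim stand in for the dimensions used at training time.

import torch

# Save only the weights rather than pickling the whole Python object.
torch.save(trained_model.network.state_dict(), model_path)

# Rebuild the architecture explicitly, then load the weights. With
# weights_only=True, torch.load refuses to unpickle arbitrary objects.
model = MLPClassifier(input_size=input_dim, hidden_size=hidden_dim)
model.network.load_state_dict(torch.load(model_path, weights_only=True))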


499-549: Consider adding logging for model loading.

The inference function logs when saving results but doesn't log when successfully loading the trained model. Adding a log statement after loading would help with debugging and monitoring.

Suggested addition
         with open(model_path, "rb") as file:
             trained_model = pickle.load(file)
+
+        log(INFO, f"Loaded trained attack classifier from {model_path}")

@emersodb (Collaborator) left a comment

This looks pretty good I think. I'd make sure to check the coderabbit comments. At least one of them seems like something we should make sure works, and maybe add a comment explaining why it's okay.

@emersodb (Collaborator) left a comment

Thanks for the changes. Still a few housekeeping things to consider, but nearly there.

("single_table", valid_single_table_models)
if is_single_table
else ("multi_table", valid_multi_table_models)
)
@emersodb (Collaborator) commented:
I see what you mean by this being a bit clunky. Maybe I'm missing something, but could you just do:

def run_inference(config: DictConfig, diffusion_model_name_override: list[str] | None = None) -> None:
    """
    Runs inference using the trained attack classifier on the challenge data.

    Args:
        config: Configuration object set in config.yaml.
        diffusion_model_name_override: If provided and valid, runs inference
            only for these models. If None or invalid, runs for all applicable models.

    Raises:
        FileNotFoundError: If the trained attack classifier model file is not found.
    """
    log(INFO, "Running inference with the trained attack classifier.")

    # Determine which diffusion models to run inference on. If an override is provided
    # and valid, use that; otherwise, use all applicable models based on the specified
    # data format (single-table or multi-table).

    is_single_table = config.attack_settings.single_table
    default_single_table_models = ["tabddpm", "tabsyn"]
    default_multi_table_models = ["clavaddpm"]

    data_format = "single_table" if is_single_table else "multi_table"

    if diffusion_model_name_override is not None:
        diffusion_models = diffusion_model_name_override
    elif is_single_table:
        diffusion_models = default_single_table_models
    else:
        diffusion_models = default_multi_table_models

If the models provided are "bad" in that they don't exist, then your loading will fail below.
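
As an optional tightening of the snippet above (hypothetical, not part of the PR), the override could be validated up front so that unknown names fail fast with a clear error rather than a later load failure; the variable names follow the snippet.

valid_models = default_single_table_models if is_single_table else default_multi_table_models

if diffusion_model_name_override is not None:
    # Fail fast on unknown names instead of hitting a FileNotFoundError
    # later when the per-model classifier is loaded.
    unknown = set(diffusion_model_name_override) - set(valid_models)
    if unknown:
        raise ValueError(f"Unknown diffusion model name(s) for {data_format}: {sorted(unknown)}")
    diffusion_models = diffusion_model_name_override
else:
    diffusion_models = valid_models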

@emersodb (Collaborator) left a comment

I think the changes look good. Just a suggestion to simplify your override approach.

@sarakodeiri merged commit 09f0078 into main on Feb 4, 2026
6 checks passed
@sarakodeiri deleted the sk/ept-classifier branch on February 4, 2026 at 17:16
