Sk/ept classifier #120
Conversation
📝 Walkthrough
This pull request introduces an end-to-end inference pipeline for the EPT attack framework. Changes include adding configuration flags for inference execution and results storage, implementing helper functions for training summarization and best attack classifier persistence, and extending the MLPClassifier with training and prediction methods.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 1 passed | ❌ 2 failed (1 warning, 1 inconclusive)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/ept_attack/run_ept_attack.py (2)
108-140: Critical: _summarize_and_save_training_results is defined twice. The function _summarize_and_save_training_results is defined at lines 108-140 and again at lines 275-307. This is a merge/rebase artifact. The second definition (lines 275-307) will shadow the first. Remove one of the duplicate definitions.
Also applies to: 275-307
144-273: Critical: run_attack_classifier_training is defined twice. The function run_attack_classifier_training is defined at lines 144-273 and again at lines 362-497. The static analysis tool also flags this (F811). The second definition appears to be the intended version, as it includes the call to _train_and_save_best_attack_classifier. Remove the first definition (lines 144-273).
Also applies to: 362-497
🤖 Fix all issues with AI agents
In `@examples/ept_attack/run_ept_attack.py`:
- Around line 531-535: The inference currently concatenates CSVs into
df_inference_features and calls trained_model.predict directly, but training
used filter_data with best_column_types in train_attack_classifier; update
inference to apply the same filtering: ensure best_column_types is persisted
with the model or passed into run_inference, then call
filter_data(df_inference_features, best_column_types) before calling
trained_model.predict; reference trained_model.predict, train_attack_classifier,
filter_data, best_column_types and run_inference to locate and update the code
paths.
- Around line 583-586: In main(), there's a duplicated conditional block that
re-checks config.pipeline.run_inference and
config.pipeline.run_attack_classifier_training; remove the second occurrence so
each pipeline check is only executed once—locate the duplicate if
config.pipeline.run_inference: ... and if
config.pipeline.run_attack_classifier_training: ... (referencing
run_inference(config) and run_attack_classifier_training(config)) and delete
that repeated block, leaving the original checks intact.
- Around line 551-552: The comment contains a typo: change "challenege" to
"challenge" in the TODO line referencing evaluation of inference results; update
the TODO comment text that mentions evaluating inference results using the
challenge labels and ensure the referenced function call
_evaluate_inference_results(predictions, diffusion_model_name) remains correctly
spelled and uncommented/implemented when you add the evaluation logic.
- Around line 11-15: Remove duplicate imports: keep a single import for pickle,
one for "from collections import defaultdict", and one for "from datetime import
datetime"; edit the top of run_ept_attack.py to eliminate the repeated lines
referencing pickle, defaultdict, and datetime so only one import statement per
symbol remains and ensure no other code depends on the removed duplicates.
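As a rough illustration of the first item above (the filter_data/best_column_types consistency fix), one possible shape is to persist the column types alongside the model and reapply the same filtering at inference time. This is only a sketch: the import path, bundle format, and helper names are assumptions, not the toolkit's actual API.

import pickle

# Assumed import path; filter_data is the project helper referenced in the comment above.
from midst_toolkit.attacks.ept.data_processing import filter_data


def save_classifier_bundle(trained_model, best_column_types, model_path: str) -> None:
    # Persist the column types used during training next to the model itself,
    # so inference can reproduce exactly the same feature filtering.
    with open(model_path, "wb") as file:
        pickle.dump({"model": trained_model, "column_types": best_column_types}, file)


def predict_with_bundle(model_path: str, df_inference_features):
    with open(model_path, "rb") as file:
        bundle = pickle.load(file)
    # Apply the same filtering that train_attack_classifier used before predicting.
    filtered_features = filter_data(df_inference_features, bundle["column_types"])
    return bundle["model"].predict(filtered_features)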
In `@tests/unit/attacks/ept_attack/test_classification.py`:
- Around line 238-241: The model fixture currently hardcodes input_dim=20 and
hidden_dim=10 in its signature instead of using the existing input_dim and
hidden_dim fixtures; update the model fixture declaration to accept input_dim
and hidden_dim as parameters (remove the default values) and construct the
MLPClassifier using those injected fixture values (e.g.,
MLPClassifier(input_size=input_dim, hidden_size=hidden_dim, ...)) so changes to
the input_dim/hidden_dim fixtures propagate to tests that depend on model.
🧹 Nitpick comments (5)
examples/ept_attack/config.yaml (1)
12-13: Minor formatting: Add blank line before the pipeline control section. There's a missing blank line between the new inference_results_path and the # Pipeline control comment, which would improve readability and maintain consistency with the rest of the file.
Suggested fix
 inference_results_path: ${data_paths.output_data_path}/inference_results # Path to save inference results
+
 # Pipeline control
src/midst_toolkit/attacks/ept/classification.py (1)
122-142: Consider mini-batch training for scalability. The current implementation passes all training data through the network as a single batch per epoch. For large datasets, this could exhaust GPU memory. Consider adding a batch_size parameter with mini-batch iteration for better scalability, or document the expected dataset size limitations.
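A hedged sketch of what the mini-batch variant could look like as an MLPClassifier method, assuming the class is a PyTorch nn.Module trained with a cross-entropy objective; the method name, optimizer, and defaults are placeholders rather than the actual implementation.

import torch
from torch.utils.data import DataLoader, TensorDataset


def fit(self, features: torch.Tensor, labels: torch.Tensor, epochs: int = 100, batch_size: int = 256) -> None:
    # Iterate over mini-batches instead of pushing the full dataset through the
    # network once per epoch, so peak memory is bounded by batch_size.
    loader = DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(self.parameters())
    loss_function = torch.nn.CrossEntropyLoss()
    self.train()
    for _ in range(epochs):
        for batch_features, batch_labels in loader:
            optimizer.zero_grad()
            loss = loss_function(self(batch_features), batch_labels)
            loss.backward()
            optimizer.step()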
tests/unit/attacks/ept_attack/test_classification.py (1)
254-260: Consider consolidating with existing test_mlp_classifier. This test duplicates much of the layer dimension checking already done in test_mlp_classifier (lines 111-145). Consider either consolidating these tests or differentiating their purposes more clearly (e.g., this one could focus on the new epochs/device parameters).
examples/ept_attack/run_ept_attack.py (2)
522-523: Note: Security consideration with pickle deserialization. Using pickle.load() on files can be a security risk if the file source is untrusted (arbitrary code execution). In this context, the model files are generated by the same pipeline, so the risk is mitigated. However, consider documenting this assumption or using a safer serialization format (e.g., torch.save/torch.load for PyTorch models, or joblib with explicit trust flags) for defense in depth.
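For the serialization nitpick, a sketch of the torch.save/torch.load alternative mentioned above, assuming the classifier is an nn.Module; the function names here are illustrative only.

import torch
from torch import nn


def save_weights(trained_model: nn.Module, path: str) -> None:
    # Persist only the weights rather than pickling the whole Python object.
    torch.save(trained_model.state_dict(), path)


def load_weights(fresh_model: nn.Module, path: str) -> nn.Module:
    # Rebuild the architecture first (fresh_model), then restore the weights.
    # weights_only=True restricts deserialization to tensor data (PyTorch >= 1.13).
    fresh_model.load_state_dict(torch.load(path, weights_only=True))
    fresh_model.eval()
    return fresh_model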
499-549: Consider adding logging for model loading. The inference function logs when saving results but doesn't log when successfully loading the trained model. Adding a log statement after loading would help with debugging and monitoring.
Suggested addition
     with open(model_path, "rb") as file:
         trained_model = pickle.load(file)
+
+    log(INFO, f"Loaded trained attack classifier from {model_path}")
emersodb
left a comment
This looks pretty good I think. I'd make sure to check the coderabbit comments. At least one of them seems like something we should at least make sure works and maybe have a comment explaining why it's okay.
emersodb
left a comment
Thanks for the changes. Still a few housekeeping things to consider, but nearly there.
| ("single_table", valid_single_table_models) | ||
| if is_single_table | ||
| else ("multi_table", valid_multi_table_models) | ||
| ) |
I see what you mean by this being a bit clunky. Maybe I'm missing something, but could you just do:
def run_inference(config: DictConfig, diffusion_model_name_override: list[str] | None = None) -> None:
    """
    Runs inference using the trained attack classifier on the challenge data.

    Args:
        config: Configuration object set in config.yaml.
        diffusion_model_name_override: If provided and valid, runs inference
            only for these models. If None or invalid, runs for all applicable models.

    Throws:
        FileNotFoundError: If the trained attack classifier model file is not found.
    """
    log(INFO, "Running inference with the trained attack classifier.")

    # Determine which diffusion models to run inference on. If an override is provided
    # and valid, use that; otherwise, use all applicable models based on the specified
    # data format (single-table or multi-table).
    is_single_table = config.attack_settings.single_table
    default_single_table_models = ["tabddpm", "tabsyn"]
    default_multi_table_models = ["clavaddpm"]
    data_format = "single_table" if is_single_table else "multi_table"

    if diffusion_model_name_override is not None:
        diffusion_models = diffusion_model_name_override
    elif is_single_table:
        diffusion_models = default_single_table_models
    else:
        diffusion_models = default_multi_table_models
If the models provided are "bad" in that they don't exist, then your loading will fail below.
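If the docstring's "If None or invalid, runs for all applicable models" behaviour is kept, one hedged way to avoid failing later at load time is a fallback with a warning. This snippet reuses the names from the suggestion above and is a sketch, not the actual implementation.

# Fall back to the defaults when the override contains unknown names, so a typo
# does not surface later as a FileNotFoundError during model loading.
valid_models = default_single_table_models if is_single_table else default_multi_table_models
if diffusion_model_name_override and set(diffusion_model_name_override) <= set(valid_models):
    diffusion_models = diffusion_model_name_override
else:
    if diffusion_model_name_override is not None:
        log(INFO, f"Override {diffusion_model_name_override} is not valid for {data_format} data; using {valid_models}.")
    diffusion_models = valid_models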
emersodb
left a comment
I think the changes look good. Just a suggestion to simplify your override approach.
PR Type
[Feature | Fix]
Short Description
Clickup Ticket(s): https://app.clickup.com/t/868g60gxr
Refactored classification pipeline and classifier classes.
Added final classifier training (somehow missed it in the previous PR, sorry).
Added basic inference to get predictions on challenge data (no comparison with ground truth yet).
Tests Added
Modified tests to sync with the refactoring.
No new tests. Will add some when the inference evaluation is complete.