
Dataset Bias Analysis Framework

Overview

Dataset Bias Analysis (DBA) is a framework for correcting bias and discrepancies between the training set and the test set to achieve higher test-set performance under prevalent subpopulation shift. This repo contains the implementation of the framework, published as the paper Boosting Test Performance with Importance Sampling--a Subpopulation Perspective (Shen et al., AAAI 2025).
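As a toy illustration of the general idea (not the repo's implementation), importance sampling reweights each training example by an estimated density ratio between the test and training distributions, so that the weighted training loss approximates the test loss. The subpopulation proportions and per-sample losses below are made up for the example:

```python
import numpy as np

# Toy setup: two subpopulations (0 = majority, 1 = minority).
# The training set is ~90/10 majority/minority; the test set is 50/50.
rng = np.random.default_rng(0)
groups = rng.choice([0, 1], size=1000, p=[0.9, 0.1])
losses = np.where(groups == 0, 0.2, 1.0)  # per-sample losses (minority is harder)

# Importance weight = test proportion / train proportion for each group.
train_p = np.array([0.9, 0.1])
test_p = np.array([0.5, 0.5])
weights = (test_p / train_p)[groups]

plain = losses.mean()  # biased toward the easy majority group
reweighted = (weights * losses).sum() / weights.sum()  # approximates the 50/50 test loss
print(plain, reweighted)
```

The reweighted loss is noticeably higher than the plain average here because the minority group, which dominates the test distribution, is upweighted.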

Preparation

Note that the implementation of the DBA method is built on top of the repo published with the paper Change is Hard: A Closer Look at Subpopulation Shift (Yang et al., ICML 2023). If you use this repo, please consider citing both the DBA paper and that one. For details, please see the Citation section.

Installation

Run the following commands to create a conda environment for running this code:

git clone git@github.com:skyve2012/DBA.git
cd DBA/
conda env create -f subpopulation_env.yml

Alternatively, one can follow the Change is Hard repo to install the packages there, and consult subpopulation_env.yml for any missing packages.

Download datasets

There are three datasets discussed in the paper.

To facilitate the implementation, the datasets can be found on Google Drive. Download and unzip all datasets into separate subfolders, and place those subfolders under one parent folder.

Dataset Bias Correction Method

To apply the correction method, follow the three steps below in order.

Model Overfit

In the following, we use the ColorMNIST dataset as an example; the same logic applies to the other datasets. The first step is to overfit one model on the training set and another on the validation set. To achieve this, run the following commands:

# overfit on the training set 
python ./subpopbench/train.py --algorithm ERM --dataset CMNISTV2 --train_attr no --data_dir path/to/dataset/parent/folder --output_dir path/to/model/folder --output_folder_name output_folder_name --cmnistv2_difficult 2pct 
# overfit on the validation set 
python ./subpopbench/train.py --algorithm ERM --dataset CMNISTV2 --train_attr no --data_dir path/to/dataset/parent/folder --output_dir path/to/model/folder --output_folder_name output_folder_name --cmnistv2_difficult 2pct --switch_train_valid

Obtain Sample Weights

Once the models are trained, we can run the following commands to obtain sample weights for the training set. The weights are saved in model.pkl as a dictionary under the key aligned_weights. Note that we specify --steps 1 to indicate no training, so the script only loads the pretrained model weights and generates the sample weights for the training set.

# using training set model
python ./subpopbench/train.py --algorithm ERM --dataset CMNISTV2 --train_attr no --data_dir path/to/dataset/parent/folder --output_dir path/to/model/folder --output_folder_name output_folder_name --gen_weights --pretrained path/to/model/overfitted/on/training/set --steps 1 --cmnistv2_difficult 2pct --switch_train_valid
# using validation set model
python ./subpopbench/train.py --algorithm ERM --dataset CMNISTV2 --train_attr no --data_dir path/to/dataset/parent/folder --output_dir path/to/model/folder --output_folder_name output_folder_name --gen_weights --pretrained path/to/model/overfitted/on/validation/set --steps 1 --cmnistv2_difficult 2pct
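To inspect the generated weights, a minimal sketch, assuming model.pkl is a standard pickle file holding the dictionary described above (round-tripped here with dummy data so the snippet is self-contained):

```python
import pickle

# The step above saves sample weights in model.pkl as a dictionary with an
# 'aligned_weights' key. The exact on-disk format is an assumption here;
# we write a dummy checkpoint first so the snippet runs on its own.
dummy_ckpt = {"aligned_weights": [0.7, 1.3, 1.0]}
with open("model.pkl", "wb") as f:
    pickle.dump(dummy_ckpt, f)

# Read the checkpoint back and pull out the per-sample weights.
with open("model.pkl", "rb") as f:
    ckpt = pickle.load(f)

sample_weights = ckpt["aligned_weights"]  # one weight per training sample
print(sample_weights)
```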

Correct the Bias

To correct the bias, one needs to first convert the sample weights into the respective .npy files, then run one of the following commands depending on the situation.
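The conversion step can be sketched as follows. This is a hedged example, not the repo's own tooling: the pickle layout follows the description above, and the file names passed to the helper are illustrative.

```python
import pickle

import numpy as np


def pkl_weights_to_npy(pkl_path, npy_path):
    """Extract 'aligned_weights' from a model.pkl and save it as a .npy file.

    Assumes model.pkl is a standard pickle of a dict keyed by
    'aligned_weights', as described in this README.
    """
    with open(pkl_path, "rb") as f:
        ckpt = pickle.load(f)
    np.save(npy_path, np.asarray(ckpt["aligned_weights"]))


# Illustrative usage, one file per overfitted model:
# pkl_weights_to_npy("train_model/model.pkl", "sample_weights_train.npy")
# pkl_weights_to_npy("valid_model/model.pkl", "sample_weights_valid.npy")
```

The resulting .npy paths are what the --sample_weights_path and --sample_weights_path_valid flags below expect.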

Known Attribute

python ./subpopbench/train.py --algorithm DBCM --dataset CMNISTV2 --train_attr yes --data_dir path/to/dataset/parent/folder --output_dir path/to/model/folder --output_folder_name output_folder_name --sample_weights_path_valid path/to/sample_weights/from/validation_set/model --sample_weights_path path/to/sample_weights/from/train_set/model --tau_valid 1000. --tau_train 1. --cmnistv2_difficult 2pct --p_maj 0.98

Unknown Attribute and Same Train-Validation Data Composition

python ./subpopbench/train.py --algorithm DBCM --dataset CMNISTV2 --train_attr no --data_dir path/to/dataset/parent/folder --output_dir path/to/model/folder --output_folder_name output_folder_name --sample_weights_path_valid path/to/sample_weights/from/validation_set/model --sample_weights_path path/to/sample_weights/from/train_set/model --tau_valid 1000. --tau_train 1. --cmnistv2_difficult 2pct --p_maj 0.98

Unknown Attribute and Different Train-Validation Data Composition

python ./subpopbench/train.py --algorithm DBCM --dataset CMNISTV2 --train_attr no --data_dir path/to/dataset/parent/folder --output_dir path/to/model/folder --output_folder_name output_folder_name --sample_weights_path_valid path/to/sample_weights/from/validation_set/model --sample_weights_path path/to/sample_weights/from/train_set/model --tau_valid 1000. --tau_train 1. --cmnistv2_difficult 2pct --p_maj 0.98 --reverse_logic

Citation

@inproceedings{shen2025subpopulation,
  title={{Boosting Test Performance with Importance Sampling--a Subpopulation Perspective}},
  author={Shen, Hongyu and Zhao, Zhizhen},
  booktitle={The Association for the Advancement of Artificial Intelligence},
  year={2025}
}
@inproceedings{yang2023change,
  title={Change is Hard: A Closer Look at Subpopulation Shift},
  author={Yang, Yuzhe and Zhang, Haoran and Katabi, Dina and Ghassemi, Marzyeh},
  booktitle={International Conference on Machine Learning},
  year={2023}
}
