This repository contains code for the paper Robust Weight Imprinting: Insights from Neural Collapse and Proxy-Based Aggregation. [Preprint] [TMLR Publication] [Video Presentation]
We test frozen, neurally collapsed foundation models (FMs) on transferability to new classes.
The weight generator (GEN) uses training data from a novel task T to consecutively generate one or more weight vectors (proxies) per class in T.
At inference time, the final output for the test data in T is computed by an aggregation (AGG) mechanism.
Embeddings and generated weights are normalized according to NORMpre and NORMpost, respectively.
During inference, embeddings are normalized according to NORMinf.
For more details on the framework implementation and components, please refer to the paper.
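As a rough illustration of how the components described above fit together, the following sketch imprints one proxy per class with GEN = mean, L2 normalization for NORMpre, NORMpost, and NORMinf, and AGG = max. It is a minimal sketch in plain Python, not the repository's implementation; the function names and the toy embeddings are assumptions made for this example.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit L2 norm (zero vectors are returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def generate_mean_proxies(embeddings, labels):
    """GEN = mean: one proxy per class, with NORMpre = NORMpost = L2."""
    sums, counts = {}, {}
    for emb, label in zip(embeddings, labels):
        emb = l2_normalize(emb)  # NORMpre
        acc = sums.setdefault(label, [0.0] * len(emb))
        for i, x in enumerate(emb):
            acc[i] += x
        counts[label] = counts.get(label, 0) + 1
    # Each class maps to a list of proxies (here a single one), NORMpost applied.
    return {
        label: [l2_normalize([s / counts[label] for s in acc])]
        for label, acc in sums.items()
    }


def predict(embedding, proxies):
    """AGG = max: pick the class whose best proxy has the highest dot product."""
    emb = l2_normalize(embedding)  # NORMinf

    def best_score(label):
        return max(sum(a * b for a, b in zip(emb, p)) for p in proxies[label])

    return max(proxies, key=best_score)


# Toy 2-D "embeddings" (made up for illustration only).
train_embeddings = [[1.0, 0.1], [0.9, 0.2], [0.1, 1.0], [0.2, 0.8]]
train_labels = ["cat", "cat", "dog", "dog"]

proxies = generate_mean_proxies(train_embeddings, train_labels)
prediction = predict([0.8, 0.3], proxies)
print(prediction)
```

Because every class maps to a list of proxies, the same `predict` function would also work if the generator returned several proxies per class.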
There are three main ways to run the experiments reproducing the results in our paper with this repository:
- Local setup with a virtual environment
- Docker container
- Kubernetes cluster for large-scale parallel execution (in conjunction with the Docker container)
# Tested with Python 3.10
# Create and activate virtual environment
python -m venv venv
source venv/bin/activate  # macOS/Linux; on Windows run venv\Scripts\activate instead
# Install requirements
pip install -r requirements.txt
# Install the package in development mode
pip install -e .
# Generate embeddings (required before running experiments)
python scripts/generate_embeddings.py
# Now you can run experiments, e.g., reproduce the results from the first results subsection
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec1_aggfixed.yaml
To run this across a Linux cluster with Kubernetes for large-scale parallel execution, build a Docker image via
docker build -t imprinting . --platform=linux/amd64
push it to your desired registry via
docker push <registry-name>/imprinting
and then use (for example) the Kubernetes job generator:
# Navigate to the k8s directory
cd k8s
# Generate job files for a specific configuration
python imprinting_jobs_generator.py
# Apply the generated job files
kubectl apply -f generated_imprinting_jobs_reprod/
The jobs will run in parallel on your Kubernetes cluster, with results stored in the configured persistent volume.
To reproduce all the experiments from our paper, run the following configuration files using the run_imprinting_experiments.py script:
# Section 5.1 (including Table 1)
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec1_aggfixed.yaml
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec1_aggfixed_kls.yaml
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec1_aggfocus.yaml
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec1_aggfocus_kls.yaml
# Figure 6
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_fig6.yaml
# Section 5.2
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec2.yaml
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec2_kls.yaml
# Section 5.3 for ImageNet
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec3_imagenet.yaml
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec3_imagenet_kls.yaml
# Section 5.3 for CombiDigits dataset
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec3_combidigits.yaml
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec3_combidigits_kls.yaml
# Section 5.3 for other datasets
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec3_non-imagenet.yaml
python scripts/run_imprinting_experiments.py --config src/config/config_reprod_subsec3_non-imagenet_kls.yaml
The neural collapse experiments provide insights into the benefits of multi-proxy imprinting and are shown in Section 5.1 of the paper:
python scripts/run_neural_collapse_experiments.py
This script calculates the NC1 metric for MNIST, FashionMNIST, CIFAR-10, the MNIST&MNIST-M&USPS&SVHN mixed set (CombiDigits), and ImageNet with different label remappings.
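To give an intuition for what a variability-collapse measure captures, the sketch below computes a simplified ratio of within-class to between-class scatter. This is an assumption-laden stand-in: it uses only the traces of the scatter matrices, and the exact NC1 definition used in the paper and in `run_neural_collapse_experiments.py` may differ. The function names and toy data are hypothetical.

```python
def class_means(embeddings, labels):
    """Return {label: mean embedding} for every class."""
    sums, counts = {}, {}
    for emb, label in zip(embeddings, labels):
        acc = sums.setdefault(label, [0.0] * len(emb))
        for i, x in enumerate(emb):
            acc[i] += x
        counts[label] = counts.get(label, 0) + 1
    return {label: [s / counts[label] for s in acc] for label, acc in sums.items()}


def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def variability_ratio(embeddings, labels):
    """Within-class scatter divided by between-class scatter (traces only).

    Simplified stand-in for an NC1-style metric; values near 0 indicate that
    embeddings of each class sit close to their class mean.
    """
    means = class_means(embeddings, labels)
    dim = len(embeddings[0])
    n = len(embeddings)
    global_mean = [sum(e[i] for e in embeddings) / n for i in range(dim)]
    within = sum(squared_distance(e, means[l]) for e, l in zip(embeddings, labels)) / n
    between = sum(squared_distance(m, global_mean) for m in means.values()) / len(means)
    if between == 0:
        raise ValueError("need at least two distinct class means")
    return within / between


# Toy data (made up): fully collapsed classes versus spread-out classes.
collapsed = variability_ratio([[0, 0], [0, 0], [1, 1], [1, 1]], [0, 0, 1, 1])
spread = variability_ratio([[0, 0], [0, 2], [4, 0], [4, 2]], [0, 0, 1, 1])
print(collapsed, spread)  # 0.0 0.25
```

Under this simplified measure, remapping several original classes to a single label would inflate the within-class term, which is one way to see why label remappings affect the degree of collapse.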
After running the experiments, use the tools in the analysis/ directory to process and visualize the results:
# Navigate to the analysis directory
cd analysis
# Run the analysis notebook
jupyter notebook analysis.ipynb
The critical difference diagram generation in cd_diag.py performs statistical significance testing to compare different imprinting configurations across multiple datasets and backbones (see Section 3.3 in the paper for a detailed explanation).
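Critical difference diagrams are built on the average rank of each configuration across tasks. The following sketch shows only that ranking step, with ties sharing the average rank; it does not perform the significance test and is not taken from `cd_diag.py`. The function names, configuration names, and accuracies are made up for illustration.

```python
def ranks_descending(scores):
    """Rank scores so the highest gets rank 1; ties share the average rank."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    ranks = [0.0] * len(scores)
    pos = 0
    while pos < len(order):
        end = pos
        while end + 1 < len(order) and scores[order[end + 1]] == scores[order[pos]]:
            end += 1
        shared = (pos + end) / 2 + 1  # mean of the 1-based positions pos+1..end+1
        for k in range(pos, end + 1):
            ranks[order[k]] = shared
        pos = end + 1
    return ranks


def average_ranks(accuracy_by_task):
    """Map {task: {config: accuracy}} to {config: mean rank over tasks}."""
    configs = sorted(next(iter(accuracy_by_task.values())))
    totals = {c: 0.0 for c in configs}
    for accs in accuracy_by_task.values():
        task_ranks = ranks_descending([accs[c] for c in configs])
        for c, r in zip(configs, task_ranks):
            totals[c] += r
    return {c: totals[c] / len(accuracy_by_task) for c in configs}


# Hypothetical accuracies, not results from the paper.
toy_results = {
    "task_1": {"config_a": 0.9, "config_b": 0.8, "config_c": 0.7},
    "task_2": {"config_a": 0.6, "config_b": 0.6, "config_c": 0.5},
    "task_3": {"config_a": 0.7, "config_b": 0.9, "config_c": 0.8},
}
avg = average_ranks(toy_results)
print(avg)
```

A lower average rank means a configuration tends to win more often; whether differences between average ranks are meaningful is what the significance test behind the diagram decides.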
Within IMPRINT, we find the best method by investigating average rank, average accuracy, and statistical significance in ranking (dis-)agreements through critical difference diagrams. The evaluation uses four FMs (resnet18, resnet50, vit_b_16, and swin_b), all pretrained on ImageNet-1K, and twelve transfer learning tasks T derived from MNIST, FashionMNIST, and CIFAR-10.
Previously studied imprinting strategies are special cases within IMPRINT.
The framework enables the creation of a novel configuration ("Ours") that outperforms previous work across FMs and Ts by a large margin with statistical significance.
The table below compares previously studied mean-imprinting strategies, our configuration, and an oracle baseline.
| Paper | NORMpre | GEN | NORMpost | NORMinf | AGG | Avg. acc. % |
|---|---|---|---|---|---|---|
| Qi et al. | L2 | mean | L2 | L2 | max | 86.79 |
| Hosoda et al. | none | mean | quantile | none | max | 82.90 |
| Janson et al. | none | mean | none | none | 1-nn | 86.64 |
| Ours | L2 | k-means | L2 | L2 | max | 91.06 |
| Oracle | none | least-squares | none | none | max | 94.54 |
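The "Ours" row combines L2 normalization, k-means weight generation, and max aggregation. The sketch below illustrates that combination on toy data with a dependency-free Lloyd's k-means and a deterministic farthest-point initialization. It is an assumption-based illustration rather than the repository's code: the function names, the number of proxies `k`, the initialization scheme, and the embeddings are all chosen for this example.

```python
import math


def l2_normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points, k, iterations=10):
    """Plain Lloyd's algorithm with deterministic farthest-point initialization."""
    centroids = [list(points[0])]
    while len(centroids) < k:
        farthest = max(points, key=lambda p: min(sq_dist(p, c) for c in centroids))
        centroids.append(list(farthest))
    for _ in range(iterations):
        clusters = [[] for _ in centroids]
        for p in points:
            nearest = min(range(len(centroids)), key=lambda j: sq_dist(p, centroids[j]))
            clusters[nearest].append(p)
        for j, members in enumerate(clusters):
            if members:  # keep the old centroid if a cluster becomes empty
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return centroids


def generate_kmeans_proxies(embeddings, labels, k):
    """GEN = k-means: up to k proxies per class, NORMpre = NORMpost = L2."""
    by_class = {}
    for emb, label in zip(embeddings, labels):
        by_class.setdefault(label, []).append(l2_normalize(emb))  # NORMpre
    return {
        label: [l2_normalize(c) for c in kmeans(pts, min(k, len(pts)))]  # NORMpost
        for label, pts in by_class.items()
    }


def predict(embedding, proxies):
    """AGG = max over all proxies of a class, with NORMinf = L2."""
    emb = l2_normalize(embedding)

    def best_score(label):
        return max(sum(a * b for a, b in zip(emb, p)) for p in proxies[label])

    return max(proxies, key=best_score)


# Toy data (made up): class "digit" has two modes, class "other" has one.
embeddings = [[1.0, 0.1], [0.9, 0.0], [-1.0, 0.1], [-0.9, 0.0], [0.0, 1.0], [0.1, 0.9]]
labels = ["digit", "digit", "digit", "digit", "other", "other"]

proxies = generate_kmeans_proxies(embeddings, labels, k=2)
print(predict([-0.8, 0.2], proxies), predict([0.1, 1.0], proxies))
```

With `k=1`, this k-means step reduces to taking the class mean, so single-proxy mean imprinting is the special case of the same procedure.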
The central effect of using multi-proxy imprinting with MNIST, FashionMNIST, CIFAR-10, and, respectively, CombiDigits at once is shown in dotted lines. It shows that using one proxy (the mean) is not optimal, as the FM does not seem to be fully collapsed on these out-of-distribution (OOD) classes.
This confirms the connection between the effect of using multiple proxies and the collapse of the data.
In the paper, especially in Figure 12, this is explored further.
If you find this work and/or repository useful for your research, please consider citing our paper:
@article{
westerhoff2025robust,
title={Robust Weight Imprinting: Insights from Neural Collapse and Proxy-Based Aggregation},
author={Justus Westerhoff and Golzar Atefi and Mario Koddenbrock and Alexei Figueroa and Alexander L{\"o}ser and Erik Rodner and Felix Alexander Gers},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2025},
url={https://openreview.net/forum?id=duU11BnQ3Y},
}