Error with example fMOW command: incorrect value of "unlabeled_n_groups_per_batch" #152

@joshuafan

Hello,
If I directly run this command suggested in the README:
python examples/run_expt.py --dataset fmow --algorithm DANN --unlabeled_split test_unlabeled --root_dir data

I get the following exception:

Traceback (most recent call last):
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 491, in <module>
    main()
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/run_expt.py", line 454, in main
    train(
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 114, in train
    run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset)
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/train.py", line 38, in run_epoch
    unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader'])
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/examples/utils.py", line 393, in __init__
    self.iter = iter(self.data_loader)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 442, in __iter__
    return self._get_iterator()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1085, in __init__
    self._reset(loader, first_iter=True)
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1118, in _reset
    self._try_put_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1352, in _try_put_index
    index = self._next_index()
  File "/home/fs01/jyf6/miniconda3/envs/ponds/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 624, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/mnt/beegfs/bulk/mirror/jyf6/datasets/wilds/wilds/common/data_loaders.py", line 131, in __iter__
    groups_for_batch = np.random.choice(
  File "mtrand.pyx", line 984, in numpy.random.mtrand.RandomState.choice
ValueError: Cannot take a larger sample than population when 'replace=False'

I think this occurs because there are only 2 unique years in the test_unlabeled split, but unlabeled_n_groups_per_batch is set to 8, so the group sampler tries to draw 8 distinct years without replacement from a population of 2.
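
To illustrate the suspected cause outside of the training script, here is a minimal sketch that only uses NumPy. The group ids 0 and 1 are placeholders standing in for the two years, and `try_sample` is a hypothetical helper written for this example, not a WILDS function.

```python
import numpy as np

# Placeholder group ids standing in for the 2 unique years (assumption).
unique_groups = np.array([0, 1])
n_groups_per_batch = 8  # value reported in the config for fMoW

def try_sample(population, k):
    """Try to draw k distinct items; return the error message on failure."""
    try:
        return np.random.choice(population, size=k, replace=False)
    except ValueError as err:
        return str(err)

result = try_sample(unique_groups, n_groups_per_batch)
print(result)
```

Asking for 8 distinct groups from a population of 2 raises the same ValueError as in the traceback, while asking for 2 or fewer succeeds.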

I was able to fix this by changing the argument unlabeled_n_groups_per_batch to 2, here: https://github.com/p-lambda/wilds/blob/main/examples/configs/datasets.py#L220
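
As an alternative to editing the config, the requested group count could perhaps be capped at the number of groups actually present in the split. The following is only a sketch of that idea: `sample_groups` is a hypothetical helper, not part of WILDS, and I have not checked how the maintainers would prefer to handle this.

```python
import numpy as np

def sample_groups(group_ids, n_groups_per_batch, rng=None):
    """Draw distinct groups, capped at the number of groups available."""
    rng = rng if rng is not None else np.random.default_rng()
    unique_groups = np.unique(group_ids)
    # Never request more distinct groups than exist in this split.
    k = min(n_groups_per_batch, len(unique_groups))
    return rng.choice(unique_groups, size=k, replace=False)

# Example: 5 samples spread over 2 groups, 8 groups requested.
picked = sample_groups(np.array([0, 0, 1, 1, 1]), n_groups_per_batch=8)
print(sorted(picked.tolist()))
```

With this cap, a split with only 2 groups yields batches drawn from both groups instead of raising an error. One trade-off is that the effective number of groups per batch silently differs from the configured value.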

It would be great if this could be fixed. Thank you so much for releasing these wonderful datasets and baseline algorithms!
