Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions robustbench/loaders.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This file is based on the code from https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py.
"""
import pkg_resources
from importlib.resources import files, as_file

from torchvision.datasets.vision import VisionDataset

Expand All @@ -17,8 +17,10 @@


def make_custom_dataset(root, path_imgs, class_to_idx):
with open(pkg_resources.resource_filename(__name__, path_imgs), 'r') as f:
fnames = f.readlines()
resource = files(__package__) / path_imgs
with as_file(resource) as file_path:
with open(file_path, 'r') as f:
fnames = f.readlines()
images = [(os.path.join(root,
c.split('\n')[0]), class_to_idx[c.split('/')[0]])
for c in fnames]
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="robustbench",
version="1.1",
version="1.1.1",
author="Francesco Croce, Maksym Andriushchenko, Vikash Sehwag, Edoardo Debenedetti",
author_email="adversarial.benchmark@gmail.com",
description="This package provides the data for RobustBench together with the model zoo.",
Expand Down
36 changes: 36 additions & 0 deletions tests/custom_loader_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import robustbench
from torchvision.datasets.vision import VisionDataset

import torch
import torch.utils.data as data
import torchvision.transforms as transforms

from PIL import Image

import os
import os.path
import sys

from robustbench import data
from robustbench import loaders

data_dir = '~/imagenet/val'

imagenet = loaders.CustomDatasetFolder(data_dir, robustbench.loaders.default_loader, transform=
transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor()]))

torch.manual_seed(0)

test_loader = data.data.DataLoader(imagenet,
batch_size=50,
shuffle=True,
num_workers=3)

x, y, path = next(iter(test_loader))

with open('path_imgs_2.txt', 'w') as f:
f.write('\n'.join(path))
f.flush()