-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathget.py
More file actions
265 lines (223 loc) · 8.86 KB
/
get.py
File metadata and controls
265 lines (223 loc) · 8.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import torch.utils.data
import os
from pprint import pprint
# Debugging aid: with CUDA_LAUNCH_BLOCKING=1, CUDA kernel launches run synchronously, so
# a GPU error is raised at the call that caused it rather than at some later operation.
# It has to be set before CUDA is initialised, which is why it is done at import time.
# NOTE(review): this mutates the environment for the whole process as a side effect of
# importing this module, and will likely slow GPU execution — confirm it is wanted outside debugging.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
from mp.data.datasets.ds_prostate import Prostate
from mp.data.datasets.ds_hippocampus import Hippocampus
from mp.data.datasets.ds_cardiac_mm import Cardiac
from mp.data.datasets.ds_optic import Optic
from mp.data.datasets.ds_mr_brain import Brain
from mp.data.datasets.ds_polyp import Polyp
from mp.data.data import Data
from mp.data.pytorch.pytorch_seg_dataset import PytorchSeg2DDataset
from torch.utils.data import DataLoader
from mp.models.continual.kd import KD
from mp.models.continual.mas import MAS
from mp.models.continual.tkrl import TKRL
from mp.eval.losses.losses_segmentation import LossDiceBCE
from mp.agents.kd_agent import KDAgent
from mp.agents.mas_agent import MASAgent
from mp.agents.ewc_agent import EWCAgent
from mp.agents.mib_agent import MIBAgent
from mp.agents.plop_agent import PLOPAgent
from mp.agents.seq_agent import SEQAgent
from mp.agents.ted_agent import TEDAgent
from mp.agents.tkrl_agent import TKRLAgent
from mp.agents.pcd_agent import PCDAgent
from mp.agents.vma_agent import VMAAgent
from mp.agents.rmae_agent import RMAEAgent
# Domain subsets, in task order, for every supported ``config["dataset"]`` key.
_DATASET_SUBSETS = {
    "brain": ("t2", "t1ce", "flair"),
    "polyp": ("C1", "C2", "C3", "C4", "C5", "C6"),
    "prostate": ("RUNMC", "BMC", "I2CVB", "UCL", "BIDMC", "HK"),
    "hippocampus": ("DecathlonHippocampus", "DryadHippocampus", "HarP"),
    "mm": ("Siemens", "Philips", "GE", "Canon"),
    "optic": ("Domain1", "Domain2", "Domain3", "Domain4"),
}

# ``config["target_class"]`` letter -> ``target`` index, only for datasets whose
# constructor takes a ``target`` argument.
_DATASET_TARGETS = {
    "mm": {"i": 1, "o": 2, "r": 3},
    "optic": {"i": 1, "o": 2},
}


def _dataset_class(dataset_key):
    """Return the dataset class that implements ``dataset_key``."""
    # Built inside the function so the imported classes are only looked up when needed.
    return {
        "brain": Brain,
        "polyp": Polyp,
        "prostate": Prostate,
        "hippocampus": Hippocampus,
        "mm": Cardiac,
        "optic": Optic,
    }[dataset_key]


def _load_domains(config):
    """Build a ``Data`` container with one dataset per domain subset; return it and the subset names."""
    dataset_key = config["dataset"]
    if dataset_key not in _DATASET_SUBSETS:
        # Previously an unknown key silently produced an empty Data() and no dataloaders.
        raise ValueError(
            f"Unknown dataset {dataset_key!r}; expected one of {sorted(_DATASET_SUBSETS)}"
        )
    subset_list = list(_DATASET_SUBSETS[dataset_key])
    dataset_cls = _dataset_class(dataset_key)
    extra_kwargs = {}
    if dataset_key in _DATASET_TARGETS:
        # An unsupported target_class letter raises KeyError, as before.
        extra_kwargs["target"] = _DATASET_TARGETS[dataset_key][config["target_class"]]
    data = Data()
    for name in subset_list:
        dataset_domain = dataset_cls(subset=name, **extra_kwargs)
        dataset_domain.name = name
        data.add_dataset(dataset_domain)
    return data, subset_list


def _make_loader(dataset, config, shuffle):
    """Wrap ``dataset`` in a DataLoader using the batch-size / worker settings in ``config``."""
    return DataLoader(
        dataset=dataset,
        batch_size=config["batch_size"],
        shuffle=shuffle,
        drop_last=False,
        pin_memory=True,
        num_workers=len(config["device_ids"]) * config["n_workers"],
    )


def get_dataset(config, exp):
    """
    Initializes and returns the dataloaders and datasets for the given configuration.
    Args:
        config (dict): Configuration dictionary with dataset and training details. Reads
            "dataset", "resume_epoch", "augmentation", "input_shape", "no_resize",
            "approach", "batch_size", "device_ids", "n_workers", and "target_class" for
            the "mm"/"optic" datasets. The optional "max_samples" key truncates every
            split to at most that many indices (a debugging aid; default: no limit).
        exp (object): Experiment object used to set data splits and fetch run 0.
    Returns:
        tuple: (train_dataloaders, test_dataloaders, datasets, exp_run, label_info).
            For the "joint" approach each dataloader list holds one loader over the
            concatenation of all domains; otherwise there is one loader per domain,
            in subset order. ``datasets`` maps (domain_name, split) to the dataset and
            ``label_info`` is {"label_nr": ..., "label_names": ...}.
    Raises:
        ValueError: if config["dataset"] is not a supported dataset.
        KeyError: if a domain has no non-empty "train" or "test" split, or a required
            config key is missing.
    """
    data, subset_list = _load_domains(config)
    exp.set_data_splits(data)
    exp_run = exp.get_run(0, reload_exp_run=(config["resume_epoch"] is not None))

    max_samples = config.get("max_samples")  # None keeps every index
    datasets = {}
    # Build one segmentation dataset per (domain, split); empty splits are skipped.
    for dataset_name, dataset in data.datasets.items():
        for split, data_indices in exp.splits[dataset_name][exp_run.run_ix].items():
            data_indices = data_indices[:max_samples]
            if len(data_indices) == 0:
                continue
            # Test splits are never augmented.
            aug_type = config["augmentation"] if "test" not in split else "none"
            datasets[(dataset_name, split)] = PytorchSeg2DDataset(
                dataset=dataset,
                ix_lst=data_indices,
                size=config["input_shape"],
                norm_key="rescaling",
                aug_key=aug_type,
                resize=(not config["no_resize"]),
                channel_labels=True,
            )

    train_sets = [datasets[(name, "train")] for name in subset_list]
    test_sets = [datasets[(name, "test")] for name in subset_list]
    if config["approach"] == "joint":
        # Joint training sees all domains at once, through a single concatenated dataset.
        train_sets = [torch.utils.data.ConcatDataset(train_sets)]
        test_sets = [torch.utils.data.ConcatDataset(test_sets)]

    train_dataloaders = [_make_loader(ds, config, shuffle=True) for ds in train_sets]
    test_dataloaders = [_make_loader(ds, config, shuffle=False) for ds in test_sets]
    label_info = {"label_nr": data.nr_labels, "label_names": data.label_names}
    return train_dataloaders, test_dataloaders, datasets, exp_run, label_info
def get_model(config, nr_labels):
    """
    Builds the model class used by the configured approach and moves it to the device.
    Args:
        config (dict): Must provide "approach", "input_shape" and "device".
        nr_labels (int): Number of labels in the dataset.
    Returns:
        torch.nn.Module: The constructed model, after ``.to(config["device"])`` was called on it.
    Raises:
        KeyError: if config["approach"] is not a supported approach.
    """
    approach = config["approach"]
    # Several approaches share one backbone class; group them by that class.
    if approach in ("mas", "ewc", "seq", "joint"):
        model_class = MAS
    elif approach in ("kd", "mib", "plop", "ted"):
        model_class = KD
    elif approach in ("tkrl", "pcd", "vma", "rmae"):
        model_class = TKRL
    else:
        raise KeyError(approach)
    model = model_class(input_shape=config["input_shape"], nr_labels=nr_labels)
    model.to(config["device"])
    return model
def get_loss_type(config):
    """
    Builds the loss function used for training.
    Args:
        config (dict): Only config["device"] is read.
    Returns:
        LossDiceBCE: Loss instance with bce_weight=1.0 and smooth=1.0 on config["device"].
    """
    loss_device = config["device"]
    return LossDiceBCE(device=loss_device, smooth=1.0, bce_weight=1.0)
def get_agent(config, model, label_names):
    """
    Builds the training agent that implements the configured approach.
    Args:
        config (dict): Must provide "approach" and "device".
        model (torch.nn.Module): Model the agent will train.
        label_names (list): Label names of the dataset, forwarded to the agent.
    Returns:
        Agent object: An instance of the agent class registered for config["approach"].
    Raises:
        KeyError: if config["approach"] is not a supported approach.
    """
    approach = config["approach"]
    agent_by_approach = dict(
        mas=MASAgent,
        ewc=EWCAgent,
        kd=KDAgent,
        mib=MIBAgent,
        plop=PLOPAgent,
        seq=SEQAgent,
        joint=SEQAgent,  # joint training reuses the sequential agent
        ted=TEDAgent,
        tkrl=TKRLAgent,
        pcd=PCDAgent,
        vma=VMAAgent,
        rmae=RMAEAgent,
    )
    agent_class = agent_by_approach[approach]
    return agent_class(model=model, label_names=label_names, device=config["device"])