-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathargs.py
More file actions
156 lines (134 loc) · 5.42 KB
/
args.py
File metadata and controls
156 lines (134 loc) · 5.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import argparse
import torch
# parse train options
def _get_parser():
"""
Creates an argument parser for training options.
Returns:
argparse.ArgumentParser: The parser containing training options.
"""
parser = argparse.ArgumentParser(description="Training Configuration Parser")
# General arguments
parser.add_argument(
"--experiment-name",
type=str,
default="",
help="Name of the experiment for tracking or resuming training",
)
parser.add_argument("--nr-runs", type=int, default=1, help="Number of training runs")
parser.add_argument("--seed", type=int, default=7, help="Random seed. default 7")
# Hardware configuration
parser.add_argument("--device", type=str, default="cuda", help="Device type, either 'cpu' or 'cuda'")
parser.add_argument("--device-ids", nargs="+", type=int, default=[5], help="IDs of GPU devices to use")
parser.add_argument(
"--n-workers",
type=int,
default=8,
help="Number of workers per GPU for data loading",
)
# Dataset-related arguments
parser.add_argument("--dataset", type=str, default="prostate", help="Dataset to be used for training")
parser.add_argument(
"--target-class",
type=str,
default="i",
help="Target class for segmentation: i, o, r",
)
parser.add_argument(
"--test-ratio",
type=float,
default=0.2,
help="Ratio of data to be used for testing",
)
parser.add_argument(
"--val-ratio",
type=float,
default=0.1,
help="Ratio of data to be used for validation",
)
parser.add_argument("--input-dim-channels", type=int, default=3, help="Number of input image channels")
parser.add_argument("--input-dim-size", type=int, default=192, help="Height and width of input images")
parser.add_argument(
"--no-resize",
action="store_true",
help="Flag to indicate if images should not be resized",
)
parser.add_argument("--augmentation", type=str, default="none", help="Type of data augmentation to apply")
# Training parameters
parser.add_argument("--epochs", type=int, default=5, help="Number of training epochs")
parser.add_argument("--batch-size", type=int, default=16, help="Batch size for training")
parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate for the first training stage")
parser.add_argument("--lr-2", type=float, default=1e-4, help="Learning rate for the second stage")
parser.add_argument("--approach", type=str, default="seq", help="Training approach to use")
parser.add_argument("--backbone", type=str, default="segformer", help="Model backbone architecture")
parser.add_argument(
"--val-best",
action="store_true",
default=True,
help="Use best validation performance to progress to the next domain",
)
# Resuming training
parser.add_argument(
"--resume-from",
type=str,
default=None,
help="Path to the directory where the model checkpoints are stored, such as optici-seq",
)
parser.add_argument(
"--resume-epoch",
type=int,
default=None,
help="Epoch number to resume training from, -1 for the latest checkpoint",
)
# Loss configuration
parser.add_argument("--loss-type", type=str, default="dice_bce", help="Type of loss function to use for training")
parser.add_argument("--lambda-d", type=float, default=0.001, help="Weight for tuning the sub-loss")
# hyperparameters
parser.add_argument("--multiply-probes", type=float, default=1, help="n batch of probes")
parser.add_argument("--boundary", type=float, default=0.99, help="Weight for boundary")
parser.add_argument("--mask-ratio", type=float, default=0.1, help="Mask ratio for the masks")
return parser
def parse_args(argv):
    """
    Parse command-line arguments and resolve the compute device.

    After parsing, ``args.device`` is rewritten to ``"cuda:<first id in
    --device-ids>"`` when ``--device cuda`` was requested and CUDA is
    available, and to ``"cpu"`` in every other case (including any other
    requested device string). The resolved device's name is printed, and
    ``args.input_shape`` is set to
    ``(input_dim_channels, input_dim_size, input_dim_size)``.

    Args:
        argv (list): List of arguments passed from the command line.
    Returns:
        Namespace: Parsed arguments as an object with attribute-style access.
    """
    parser = _get_parser()
    args = parser.parse_args(argv)

    # Pin to the first listed GPU; anything else (or no CUDA) falls back to CPU.
    if torch.cuda.is_available() and args.device == "cuda":
        args.device = f"cuda:{args.device_ids[0]}"
    else:
        args.device = "cpu"

    # args.device is now "cuda:<id>" or "cpu". The previous check for exactly
    # "cuda" could never match, so the GPU name was never reported.
    if args.device.startswith("cuda") and 0 <= args.device_ids[0] < torch.cuda.device_count():
        device_name = torch.cuda.get_device_name(args.device)
    else:
        # CPU, or a GPU index this machine does not have: report the device
        # string (as before) rather than raising here.
        device_name = args.device
    print(f"Device name: {device_name}")

    args.input_shape = (args.input_dim_channels, args.input_dim_size, args.input_dim_size)
    return args
def parse_args_as_dict(argv):
    """
    Parse ``argv`` with ``parse_args`` and expose the result as a dict.

    Args:
        argv (list): Command-line tokens to parse.
    Returns:
        dict: The parsed namespace's attribute dictionary (``vars`` of it).
    """
    namespace = parse_args(argv)
    return vars(namespace)
def parse_dict_as_args(dictionary):
    """
    Convert a dictionary of arguments into parsed arguments.

    The dictionary is rendered as a command-line style token list
    (``--key value``) and parsed with ``parse_args``. The result therefore
    goes through the same defaults and post-processing as a real command line.

    Keys may use underscores (as returned by ``vars(args)`` /
    ``parse_args_as_dict``) or hyphens. Underscores are converted to hyphens
    to match the option names declared by the parser. Values are rendered as:

    - ``True`` booleans become a bare ``--key`` flag. ``False`` booleans are
      omitted, so the parser default applies.
    - ``None`` values are omitted, so the parser default applies.
    - Lists and tuples are expanded into one token per element (for
      ``nargs="+"`` options such as ``--device-ids``).
    - Anything else is passed as ``str(value)``.

    The ``input_shape`` key is skipped because ``parse_args`` recomputes it
    from ``input_dim_channels`` and ``input_dim_size``.

    NOTE(review): a ``device`` value already resolved by ``parse_args`` (e.g.
    ``"cuda:5"``) is not equal to ``"cuda"``, so ``parse_args`` resolves it to
    ``"cpu"``. Pass ``"cuda"`` if the GPU should be kept on a round trip.

    Args:
        dictionary (dict): Dictionary of argument names and values.
    Returns:
        Namespace: Parsed arguments as an object with attribute-style access.
    """
    argv = []
    for key, value in dictionary.items():
        option = key.replace("_", "-")
        # Derived by parse_args, not a parser option; None means "use default".
        if option == "input-shape" or value is None:
            continue
        flag = f"--{option}"
        if isinstance(value, bool):
            if value:
                argv.append(flag)
        elif isinstance(value, (list, tuple)):
            argv.append(flag)
            argv.extend(str(item) for item in value)
        else:
            argv.append(flag)
            argv.append(str(value))
    return parse_args(argv)