-
Notifications
You must be signed in to change notification settings - Fork 41
Gan implementation first pass #160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1162,12 +1162,15 @@ def __init__(self, job_dir: str, job_identifier: str): | |
| # * Semantic segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' | ||
| # * Instance segmentation: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' | ||
| # * Detection: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'unext_v1', 'unext_v2' | ||
| # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2' | ||
| # * Denoising: 'unet', 'resunet', 'resunet++', 'attention_unet', 'seunet', 'resunet_se', 'unext_v1', 'unext_v2', 'nafnet' | ||
| # * Super-resolution: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'multiresunet', 'unext_v1', 'unext_v2' | ||
| # * Self-supervision: 'unet', 'resunet', 'resunet++', 'attention_unet', 'multiresunet', 'seunet', 'resunet_se', 'unetr', 'edsr', 'rcan', 'dfcan', 'wdsr', 'vit', 'mae', 'unext_v1', 'unext_v2' | ||
| # * Classification: 'simple_cnn', 'vit', 'efficientnet_b[0-7]' (only 2D) | ||
| # * Image to image: 'edsr', 'rcan', 'dfcan', 'wdsr', 'unet', 'resunet', 'resunet++', 'seunet', 'resunet_se', 'attention_unet', 'unetr', 'multiresunet', 'unext_v1', 'unext_v2' | ||
| _C.MODEL.ARCHITECTURE = "unet" | ||
| # Architecture of the network. Possible values are: | ||
| # * 'patchgan' | ||
| _C.MODEL.ARCHITECTURE_D = "patchgan" | ||
| # Number of feature maps on each level of the network. | ||
| _C.MODEL.FEATURE_MAPS = [16, 32, 64, 128, 256] | ||
| # Values to make the dropout with. Set to 0 to prevent dropout. When using it with 'ViT' or 'unetr' | ||
|
|
@@ -1306,6 +1309,26 @@ def __init__(self, job_dir: str, job_identifier: str): | |
| # Whether to use a pretrained version of STUNet on ImageNet | ||
| _C.MODEL.STUNET.PRETRAINED = False | ||
|
|
||
| # NafNet | ||
| _C.MODEL.NAFNET = CN() | ||
| # Base number of channels (width) used in the first layer and base levels. | ||
| _C.MODEL.NAFNET.WIDTH = 16 | ||
| # Number of NAFBlocks stacked at the bottleneck (deepest level). | ||
| _C.MODEL.NAFNET.MIDDLE_BLK_NUM = 12 | ||
| # Number of NAFBlocks assigned to each downsampling level of the encoder. | ||
| _C.MODEL.NAFNET.ENC_BLK_NUMS = [2, 2, 4, 8] | ||
| # Number of NAFBlocks assigned to each upsampling level of the decoder. | ||
| _C.MODEL.NAFNET.DEC_BLK_NUMS = [2, 2, 2, 2] | ||
| # Channel expansion factor for the depthwise convolution within the gating unit. | ||
| _C.MODEL.NAFNET.DW_EXPAND = 2 | ||
| # Expansion factor for the hidden layer within the feed-forward network. | ||
| _C.MODEL.NAFNET.FFN_EXPAND = 2 | ||
|
|
||
| # Discriminator PATCHGAN | ||
| _C.MODEL.PATCHGAN = CN() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move |
||
| # Number of initial convolutional filters in the first layer of the discriminator. | ||
| _C.MODEL.PATCHGAN.BASE_FILTERS = 64 | ||
|
|
||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
| # Loss | ||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
|
@@ -1371,7 +1394,24 @@ def __init__(self, job_dir: str, job_identifier: str): | |
| _C.LOSS.CONTRAST.MEMORY_SIZE = 5000 | ||
| _C.LOSS.CONTRAST.PROJ_DIM = 256 | ||
| _C.LOSS.CONTRAST.PIXEL_UPD_FREQ = 10 | ||
|
|
||
|
|
||
| # Fine-grained GAN composition. Set any weight to 0.0 to disable that term. | ||
| # Used when LOSS.TYPE == "COMPOSED_GAN". | ||
| _C.LOSS.COMPOSED_GAN = CN() | ||
| # Weight for adversarial BCE term. | ||
| _C.LOSS.COMPOSED_GAN.LAMBDA_GAN = 1.0 | ||
| # Weight for L1 reconstruction term. | ||
| _C.LOSS.COMPOSED_GAN.LAMBDA_RECON = 10.0 | ||
| # Weight for MSE reconstruction term. | ||
| _C.LOSS.COMPOSED_GAN.DELTA_MSE = 0.0 | ||
| # Weight for VGG perceptual term. | ||
| _C.LOSS.COMPOSED_GAN.ALPHA_PERCEPTUAL = 0.0 | ||
| # Weight for SSIM term. | ||
| _C.LOSS.COMPOSED_GAN.GAMMA_SSIM = 1.0 | ||
|
|
||
| # Backward-compatible alias for previous naming. | ||
| _C.LOSS.GAN = CN() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the purpose of this? It shouldn't be necessary |
||
|
|
||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
| # Training phase | ||
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
|
@@ -1381,12 +1421,18 @@ def __init__(self, job_dir: str, job_identifier: str): | |
| _C.TRAIN.VERBOSE = False | ||
| # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" | ||
| _C.TRAIN.OPTIMIZER = "SGD" | ||
| # Optimizer to use. Possible values: "SGD", "ADAM" or "ADAMW" for GAN discriminator | ||
| _C.TRAIN.OPTIMIZER_D = "SGD" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that the more than one opt is used change |
||
| # Learning rate | ||
| _C.TRAIN.LR = 1.0e-4 | ||
| # Learning rate for GAN discriminator | ||
| _C.TRAIN.LR_D = 1.0e-4 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as with optimizers: |
||
| # Weight decay | ||
| _C.TRAIN.W_DECAY = 0.02 | ||
| # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optimizers | ||
| _C.TRAIN.OPT_BETAS = (0.9, 0.999) | ||
| # Coefficients used for computing running averages of gradient and its square. Used in ADAM and ADAMW optimizers for the GAN's discriminator | ||
| _C.TRAIN.OPT_BETAS_D = (0.5, 0.999) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same: |
||
| # Batch size | ||
| _C.TRAIN.BATCH_SIZE = 2 | ||
| # If memory or # gpus is limited, use this variable to maintain the effective batch size, which is | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -246,7 +246,7 @@ def create_train_val_augmentors( | |
| dic["zflip"] = cfg.AUGMENTOR.ZFLIP | ||
| if cfg.PROBLEM.TYPE == "INSTANCE_SEG": | ||
| dic["instance_problem"] = True | ||
| elif cfg.PROBLEM.TYPE == "DENOISING": | ||
| elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.LOSS.TYPE != "COMPOSED_GAN": | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should be checking the model, e.g. |
||
| dic["n2v"] = True | ||
| dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX | ||
| dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR | ||
|
|
@@ -293,7 +293,7 @@ def create_train_val_augmentors( | |
| ) | ||
| if cfg.PROBLEM.TYPE == "INSTANCE_SEG": | ||
| dic["instance_problem"] = True | ||
| elif cfg.PROBLEM.TYPE == "DENOISING": | ||
| elif cfg.PROBLEM.TYPE == "DENOISING" and cfg.LOSS.TYPE != "COMPOSED_GAN": | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as prev comment |
||
| dic["n2v"] = True | ||
| dic["n2v_perc_pix"] = cfg.PROBLEM.DENOISING.N2V_PERC_PIX | ||
| dic["n2v_manipulator"] = cfg.PROBLEM.DENOISING.N2V_MANIPULATOR | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ def prepare_optimizer( | |
| cfg: CN, | ||
| model_without_ddp: nn.Module | nn.parallel.DistributedDataParallel, | ||
| steps_per_epoch: int, | ||
| is_gan: bool = False, | ||
| ) -> Tuple[Optimizer, Scheduler | None]: | ||
| """ | ||
| Create and configure the optimizer and learning rate scheduler for the given model. | ||
|
|
@@ -33,57 +34,89 @@ def prepare_optimizer( | |
| ---------- | ||
| cfg : YACS CN object | ||
| Configuration object with optimizer and scheduler settings. | ||
| model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel | ||
| model_without_ddp : nn.Module or nn.parallel.DistributedDataParallel or dict | ||
| The model to optimize. | ||
| steps_per_epoch : int | ||
| Number of steps (batches) per training epoch. | ||
| is_gan : bool, optional | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't need this extra argument. Also, the whole function is modified a lot. You can do it simpler: make a for with the optimizers requested and return a list of optimizers/lr_schedulers for each |
||
| Whether to create optimizer/scheduler pairs for GAN generator and discriminator. | ||
|
|
||
| Returns | ||
| ------- | ||
| optimizer : Optimizer | ||
| Configured optimizer for the model. | ||
| lr_scheduler : Scheduler or None | ||
| Configured learning rate scheduler, or None if not specified. | ||
| optimizer : Optimizer or dict | ||
| Configured optimizer for the model or dict with generator/discriminator optimizers in GAN mode. | ||
| lr_scheduler : Scheduler or None or dict | ||
| Configured scheduler for the model or dict with generator/discriminator schedulers in GAN mode. | ||
| """ | ||
| lr = cfg.TRAIN.LR if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR | ||
| opt_args = {} | ||
| if cfg.TRAIN.OPTIMIZER in ["ADAM", "ADAMW"]: | ||
| opt_args["betas"] = cfg.TRAIN.OPT_BETAS | ||
| optimizer = timm.optim.create_optimizer_v2( | ||
| model_without_ddp, | ||
| opt=cfg.TRAIN.OPTIMIZER, | ||
| lr=lr, | ||
| weight_decay=cfg.TRAIN.W_DECAY, | ||
| **opt_args, | ||
| ) | ||
| print(optimizer) | ||
|
|
||
| # Learning rate schedulers | ||
| lr_scheduler = None | ||
| if cfg.TRAIN.LR_SCHEDULER.NAME != "": | ||
| if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": | ||
| lr_scheduler = ReduceLROnPlateau( | ||
| optimizer, | ||
| patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, | ||
| factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, | ||
| min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, | ||
| ) | ||
| elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": | ||
| lr_scheduler = WarmUpCosineDecayScheduler( | ||
| lr=cfg.TRAIN.LR, | ||
| min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, | ||
| warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, | ||
| epochs=cfg.TRAIN.EPOCHS, | ||
| ) | ||
| elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": | ||
| lr_scheduler = OneCycleLR( | ||
| optimizer, | ||
| cfg.TRAIN.LR, | ||
| epochs=cfg.TRAIN.EPOCHS, | ||
| steps_per_epoch=steps_per_epoch, | ||
| ) | ||
|
|
||
| return optimizer, lr_scheduler | ||
| def _make_scheduler(optimizer: Optimizer, lr_value: float) -> Scheduler | None: | ||
| lr_scheduler = None | ||
| if cfg.TRAIN.LR_SCHEDULER.NAME != "": | ||
| if cfg.TRAIN.LR_SCHEDULER.NAME == "reduceonplateau": | ||
| lr_scheduler = ReduceLROnPlateau( | ||
| optimizer, | ||
| patience=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_PATIENCE, | ||
| factor=cfg.TRAIN.LR_SCHEDULER.REDUCEONPLATEAU_FACTOR, | ||
| min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, | ||
| ) | ||
| elif cfg.TRAIN.LR_SCHEDULER.NAME == "warmupcosine": | ||
| lr_scheduler = WarmUpCosineDecayScheduler( | ||
| lr=lr_value, | ||
| min_lr=cfg.TRAIN.LR_SCHEDULER.MIN_LR, | ||
| warmup_epochs=cfg.TRAIN.LR_SCHEDULER.WARMUP_COSINE_DECAY_EPOCHS, | ||
| epochs=cfg.TRAIN.EPOCHS, | ||
| ) | ||
| elif cfg.TRAIN.LR_SCHEDULER.NAME == "onecycle": | ||
| lr_scheduler = OneCycleLR( | ||
| optimizer, | ||
| lr_value, | ||
| epochs=cfg.TRAIN.EPOCHS, | ||
| steps_per_epoch=steps_per_epoch, | ||
| ) | ||
| return lr_scheduler | ||
|
|
||
| def _make_optimizer(model: nn.Module | nn.parallel.DistributedDataParallel, train_cfg: dict): | ||
| lr_value = train_cfg["lr"] | ||
| opt_name = train_cfg["optimizer"] | ||
| betas = train_cfg["betas"] | ||
| w_decay = train_cfg["weight_decay"] | ||
|
|
||
| lr = lr_value if cfg.TRAIN.LR_SCHEDULER.NAME != "warmupcosine" else cfg.TRAIN.LR_SCHEDULER.MIN_LR | ||
| opt_args = {} | ||
| if opt_name in ["ADAM", "ADAMW"]: | ||
| opt_args["betas"] = betas | ||
|
|
||
| optimizer = timm.optim.create_optimizer_v2( | ||
| model, | ||
| opt=opt_name, | ||
| lr=lr, | ||
| weight_decay=w_decay, | ||
| **opt_args, | ||
| ) | ||
| print(optimizer) | ||
| lr_scheduler = _make_scheduler(optimizer, lr_value) | ||
| return optimizer, lr_scheduler | ||
|
|
||
| g_train_cfg = { | ||
| "lr": cfg.TRAIN.LR, | ||
| "optimizer": cfg.TRAIN.OPTIMIZER, | ||
| "betas": cfg.TRAIN.OPT_BETAS, | ||
| "weight_decay": cfg.TRAIN.W_DECAY, | ||
| } | ||
|
|
||
| if not is_gan: | ||
| return _make_optimizer(model_without_ddp, g_train_cfg) | ||
|
|
||
| d_train_cfg = { | ||
| "lr": cfg.TRAIN.LR_D, | ||
| "optimizer": cfg.TRAIN.OPTIMIZER_D, | ||
| "betas": cfg.TRAIN.OPT_BETAS_D, | ||
| "weight_decay": cfg.TRAIN.W_DECAY, | ||
| } | ||
|
|
||
| optimizer_g, scheduler_g = _make_optimizer(model_without_ddp["generator"], g_train_cfg) | ||
| optimizer_d, scheduler_d = _make_optimizer(model_without_ddp["discriminator"], d_train_cfg) | ||
|
|
||
| return {"generator": optimizer_g, "discriminator": optimizer_d}, {"generator": scheduler_g, "discriminator": scheduler_d,} | ||
|
|
||
|
|
||
| def build_callbacks(cfg: CN) -> EarlyStopping | None: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a feature only of nafnet so introduce it inside
MODEL.NAFNET