|
class CNNBackbone(ModelABC):
    """Retrieve the model backbone and strip the classification layer.

    This is a wrapper for pretrained models within pytorch. The feature
    extractor returned by ``_get_architecture`` is followed by a global
    average pooling layer, so every input image is reduced to a single
    feature vector.

    Args:
        backbone (str):
            Model name. Currently, the tool supports following
            model names and their default associated weights from pytorch.
                - "alexnet"
                - "resnet18"
                - "resnet34"
                - "resnet50"
                - "resnet101"
                - "resnext50_32x4d"
                - "resnext101_32x8d"
                - "wide_resnet50_2"
                - "wide_resnet101_2"
                - "densenet121"
                - "densenet161"
                - "densenet169"
                - "densenet201"
                - "inception_v3"
                - "googlenet"
                - "mobilenet_v2"
                - "mobilenet_v3_large"
                - "mobilenet_v3_small"

    Examples:
        >>> # Creating resnet50 architecture from default pytorch
        >>> # without the classification layer with its associated
        >>> # weights loaded
        >>> model = CNNBackbone(backbone="resnet50")
        >>> model.eval()  # set to evaluation mode
        >>> # dummy sample in NCHW form
        >>> samples = torch.rand(4, 3, 512, 512)
        >>> features = model(samples)
        >>> features.shape  # features after global average pooling
        torch.Size([4, 2048])

    """

    def __init__(self: CNNBackbone, backbone: str) -> None:
        """Initialize :class:`CNNBackbone`.

        Args:
            backbone (str):
                Architecture name, forwarded to ``_get_architecture``.

        """
        super().__init__()
        self.feat_extract = _get_architecture(backbone)
        # Collapse the spatial dimensions to 1x1 whatever the input size is.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    # pylint: disable=W0221
    # because abc is generic, this is actual definition
    def forward(self: CNNBackbone, imgs: torch.Tensor) -> torch.Tensor:
        """Pass input data through the model.

        Args:
            imgs (torch.Tensor):
                Model input in NCHW layout.

        Returns:
            torch.Tensor:
                Globally average-pooled features flattened to shape
                ``(N, C)``, where ``C`` is the channel count produced by
                the feature extractor.

        """
        feat = self.feat_extract(imgs)
        gap_feat = self.pool(feat)
        # Drop the trailing 1x1 spatial dimensions; keep the batch dimension.
        return torch.flatten(gap_feat, 1)

    @staticmethod
    def infer_batch(
        model: nn.Module,
        batch_data: torch.Tensor,
        *,
        on_gpu: bool,
    ) -> list[np.ndarray]:
        """Run inference on an input batch.

        Contains logic for forward operation as well as i/o aggregation.

        Args:
            model (nn.Module):
                PyTorch defined model.
            batch_data (torch.Tensor):
                A batch of data generated by
                `torch.utils.data.DataLoader`, in NHWC (channels-last)
                layout. It is moved to the selected device, cast to
                float32 and permuted to NCHW before the forward pass.
            on_gpu (bool):
                Whether to run inference on a GPU.

        Returns:
            list[np.ndarray]:
                A single-element list holding the model output, moved to
                the CPU and converted to a numpy array.

        """
        img_patches_device = batch_data.to(select_device(on_gpu=on_gpu)).type(
            torch.float32,
        )
        # NHWC -> NCHW, the layout expected by the convolutional backbone.
        img_patches_device = img_patches_device.permute(0, 3, 1, 2).contiguous()

        # Inference mode (note: ``eval()`` switches the passed model in place).
        model.eval()
        # Do not compute the gradient (not training)
        with torch.inference_mode():
            output = model(img_patches_device)
        # Output should be a single tensor or scalar
        return [output.cpu().numpy()]
Description
I think it would be useful to integrate pre-trained foundation models from other labs into
`tiatoolbox.models.architecture.vanilla.py`. Currently, the
`_get_architecture()` function allows the use of models from `torchvision.models`. But another function,
`_get_timm_architecture()`, could be added to incorporate foundation models that are available from `timm` with weights on the HuggingFace Hub. All the models from `timm` that I've used require users to sign the licence agreement with the authors. The licensing question therefore seems to solve itself: users cannot get the model weights through TIAToolbox alone until the authors have approved their access request.
To add them myself, I copied the definition of `CNNBackbone`, changing
`self.feat_extract = _get_timm_architecture(backbone)` (batch_size, embedding_size). See `tiatoolbox/tiatoolbox/models/architecture/vanilla.py`,
lines 176 to 270 in 015652c.
Suggestion
Would you be interested in adding this functionality? If yes, I can make a pull request.