Description
I have implemented partial convolution and stumbled on the problem that layer activations peak at the edges, even though Figure 5 of the "Partial Convolution based Padding" paper (paper) explicitly says "Red rectangles show the strong activation regions from VGG19 network with zero padding":

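To put a number on the edge peaking (a hypothetical helper of mine, not part of the original code): compare the mean absolute activation on a one-pixel border frame against the interior mean; a ratio well above 1 means the borders dominate.

import torch

def border_to_interior_ratio(activation: torch.Tensor, border: int = 1) -> float:
    """Mean |activation| on a `border`-pixel frame divided by the interior mean."""
    magnitude = activation.abs().mean(dim=0)                  # (H, W), averaged over channels
    interior = magnitude[border:-border, border:-border]
    border_sum = magnitude.sum() - interior.sum()
    border_numel = magnitude.numel() - interior.numel()
    return ((border_sum / border_numel) / interior.mean()).item()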
I started to double-check my implementation, and it turned out to be similar to this repo's. Then I began to think about why this happens. After some trial and error I came up with a simple solution: just convolve the mask with mask_weight, then normalize the mask by dividing it by the maximum value in the mask.
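To illustrate the difference, here is a minimal single-channel sketch (my own, with made-up names): with an all-valid mask and zero padding, the original partial-conv re-weighting multiplies border outputs by slide_winsize / valid_count > 1 at every layer, while the proposed max-normalized mask stays at or below 1 and simply down-weights the next layer's border inputs.

import torch
import torch.nn.functional as F

k, pad = 3, 1
mask = torch.ones(1, 1, 8, 8)                      # all-valid input mask
mask_weight = torch.ones(1, 1, k, k)
valid = F.conv2d(mask, mask_weight, padding=pad)   # valid pixels per window: 4 at corners, 6 at edges, 9 inside

partial_ratio = (k * k) / (valid + 1e-8)           # original re-weighting: amplifies borders (9/4 at corners)
max_norm = valid / valid.max()                     # proposed: attenuates borders (4/9 at corners)

print(partial_ratio[0, 0, :3, :3])
print(max_norm[0, 0, :3, :3])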
Here is code for your reference, so you can double-check your implementation, my implementation, and the fix yourselves:
Code
from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor
class PartialConv2d(nn.Conv2d):
    def __init__(self, *args, **kwargs):
        # whether the mask is multi-channel or not
        if 'multi_channel' in kwargs:
            self.multi_channel = kwargs['multi_channel']
            kwargs.pop('multi_channel')
        else:
            self.multi_channel = False

        if 'return_mask' in kwargs:
            self.return_mask = kwargs['return_mask']
            kwargs.pop('return_mask')
        else:
            self.return_mask = False

        super(PartialConv2d, self).__init__(*args, **kwargs)

        if self.multi_channel:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(self.out_channels, self.in_channels,
                                                   self.kernel_size[0], self.kernel_size[1]))
        else:
            self.register_buffer(name='weight_maskUpdater', persistent=False,
                                 tensor=torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]))

        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3]
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        assert len(input.shape) == 4
        if mask_in is not None or self.last_size != tuple(input.shape):
            self.last_size = tuple(input.shape)
            with torch.no_grad():
                if mask_in is None:
                    # if mask is not provided, create a mask
                    if self.multi_channel:
                        mask = torch.ones_like(input)
                    else:
                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], device=input.device, dtype=input.dtype)
                else:
                    mask = mask_in

                self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)

        if self.bias is not None:
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output
class MaskedConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = 'zeros',
        eps=1e-8,
        multichannel: bool = False,
        partial_conv: bool = False,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
        self,
        input: torch.Tensor,
        mask: torch.Tensor | None
    ) -> torch.Tensor:
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
        self,
        input: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        if mask is not None:
            input = input * mask  # out-of-place, so the caller's tensor is not mutated
        mask = self.get_mask(input, mask)
        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)
            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)
            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)
            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)
            # proposed fix: normalize the convolved mask by its maximum value
            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals
        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"
class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> tuple[Tensor, Tensor | None]:
        return super().forward(input), super().forward(mask) if mask is not None else None
class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> tuple[Tensor, Tensor | None]:
        for module in self:
            input, mask = module(input, mask)
        return input, mask
@contextmanager
def register_hooks(
    model: torch.nn.Module,
    hook: Callable,
    predicate: Callable[[str, torch.nn.Module], bool],
    **hook_kwargs
):
    handles = []
    try:
        for name, module in model.named_modules():
            if predicate(name, module):
                # bind this module's name into the hook without rebinding `hook` itself,
                # so later iterations don't wrap an already-bound hook
                module_hook: Callable = partial(hook, name=name, **hook_kwargs)
                handle = module.register_forward_hook(module_hook)
                handles.append(handle)
        yield handles
    finally:
        for handle in handles:
            handle.remove()
def activations_recorder_hook(
    module: torch.nn.Module,
    input: torch.Tensor,
    output: torch.Tensor,
    name: str,
    *,
    storage: dict[str, Any]
):
    if name in storage:
        if isinstance(storage[name], list):
            storage[name].append(output)
        else:
            storage[name] = [storage[name], output]
    else:
        storage[name] = output
def forward_with_activations(
    model: torch.nn.Module,
    predicate: Callable[[str, torch.nn.Module], bool],
    *model_args,
    **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    storage = {}
    with register_hooks(model, activations_recorder_hook, predicate, storage=storage):
        output = model(*model_args, **model_kwargs)
    return output, storage
def test_it():
    torch.manual_seed(37)
    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 4
    eps = 1e-8

    # pconv: the original partial convolution implementation (this repo)
    pconv = []
    for i in range(depth):
        pconv.append(MaskedPixelUnshuffle(downscale_factor))
        pconv.append(PartialConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multi_channel=True, return_mask=True)
        )
    pconv = MaskedSequential(*pconv)

    # mpconv: my re-implementation of partial convolution
    mpconv = []
    for i in range(depth):
        mpconv.append(MaskedPixelUnshuffle(downscale_factor))
        mpconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=True)
        )
    mpconv = MaskedSequential(*mpconv)

    # mconv: my masked convolution with max-normalized mask propagation
    mconv = []
    for i in range(depth):
        mconv.append(MaskedPixelUnshuffle(downscale_factor))
        mconv.append(MaskedConv2d(
            in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
            out_channels=scale * base ** i * downscale_factor ** 2,
            kernel_size=(3, 3), padding=1, bias=False, multichannel=True, partial_conv=False)
        )
    mconv = MaskedSequential(*mconv)

    with torch.no_grad():
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")
        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")
        mpconv.load_state_dict(pconv.state_dict())
        mconv.load_state_dict(pconv.state_dict())

        x = torch.randn(1, in_channels, downscale_factor ** depth, downscale_factor ** depth)
        mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x)

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
        (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
        (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)

        # mpconv must reproduce pconv exactly; mconv must differ
        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # fig, axs = plt.subplots(nrows=visualize_depth, ncols=3, figsize=(12, 8), dpi=180)
        fig, axs = plt.subplots(nrows=3, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()
        for impl_i, (name, y, mask, activations) in enumerate([
            ("pconv", y_pconv, mask_pconv, activations_pconv),
            ("mpconv", y_mpconv, mask_mpconv, activations_mpconv),
            ("mconv", y_mconv, mask_mconv, activations_mconv)
        ]):
            batch_i = 0
            for depth_i in range(visualize_depth):
                # ax = axs[depth_i * 3 + impl_i]
                ax = axs[impl_i * visualize_depth + depth_i]
                output = activations[f"{depth_i * 2 + 1}"][0][batch_i]
                mask_output = activations[f"{depth_i * 2 + 1}"][1][batch_i]
                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")
                ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-std, vmax=std)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')
        # plt.suptitle(f"Depth {depth_i}")
        plt.show()
if __name__ == '__main__':
    test_it()

Output:
name='pconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='pconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='pconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='pconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mpconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='mpconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='mpconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='mpconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mconv', depth_i=0, mean=tensor(-0.0039), std=tensor(0.5769), skewness=tensor(0.0052), kurtosis=tensor(3.0468)
name='mconv', depth_i=1, mean=tensor(-0.0016), std=tensor(0.3209), skewness=tensor(-0.0099), kurtosis=tensor(3.0444)
name='mconv', depth_i=2, mean=tensor(-0.0003), std=tensor(0.1796), skewness=tensor(-0.0102), kurtosis=tensor(3.1047)
name='mconv', depth_i=3, mean=tensor(-0.0011), std=tensor(0.0973), skewness=tensor(-0.0421), kurtosis=tensor(3.3349)
- pconv is the original implementation of partial conv (this repo)
- mpconv is my implementation of partial conv
- mconv is my approach to masked convolution
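For intuition on why mconv behaves differently, here is a minimal single-channel sketch (my own, mirroring the partial_conv=False branch above) of how the max-normalized mask propagates through stacked 3x3, padding=1 layers: border values decay smoothly instead of being clamped back to 1, so each subsequent layer sees progressively down-weighted edges rather than re-amplified ones.

import torch
import torch.nn.functional as F

# convolve the all-ones mask, then divide by its max, repeatedly
mask = torch.ones(1, 1, 16, 16)
ones_kernel = torch.ones(1, 1, 3, 3)
for layer in range(4):
    mask = F.conv2d(mask, ones_kernel, padding=1)
    mask = mask / mask.max()
    print(f"{layer=}, corner={mask[0, 0, 0, 0]:.3f}, center={mask[0, 0, 8, 8]:.3f}")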
Here are also the activations on real images:


