Skip to content

Map at edges is peaking (PartialConv2d implementation + fix) #44

@ivanstepanovftw

Description

@ivanstepanovftw

I have implemented partial convolution and stumbled upon a problem: layer activations peak at the edges, even though Figure 5 of the "Partial Convolution based Padding" paper explicitly says that "Red rectangles show the strong activation regions from VGG19 network with zero padding":
image


I started to double-check my implementation, and it turned out to be equivalent to the one in this repo. Then I started to think about why this is happening. After some trial and error I came up with a simple solution: just convolve the mask with mask_weight, then normalize the resulting mask by dividing it by its maximum value.

Here is the code for your reference, so you can double-check your implementation, my implementation, and the fix yourself:

Code

from contextlib import contextmanager
from functools import partial
from typing import Tuple, Any, Callable

import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from torch import nn, Tensor


class PartialConv2d(nn.Conv2d):
    """Conv2d whose output is re-weighted by how much of each window the mask covers.

    Accepts every ``nn.Conv2d`` argument plus two keyword-only extras:

    * ``multi_channel`` (default ``False``): when true the mask-updating kernel has
      shape ``(out_channels, in_channels, kh, kw)``; otherwise ``(1, 1, kh, kw)``.
    * ``return_mask`` (default ``False``): when true ``forward`` returns
      ``(output, update_mask)`` instead of just ``output``.

    The updated mask and the re-weighting ratio are cached on the module and only
    recomputed when a mask is passed in or the input shape changes.
    """

    def __init__(self, *args, **kwargs):
        # Strip the two extra options so that nn.Conv2d never sees them.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super().__init__(*args, **kwargs)

        kernel_h, kernel_w = self.kernel_size
        if self.multi_channel:
            updater_shape = (self.out_channels, self.in_channels, kernel_h, kernel_w)
        else:
            updater_shape = (1, 1, kernel_h, kernel_w)
        # All-ones kernel used to sum the mask over each window; it is a non-persistent
        # buffer, so it does not appear in the state dict.
        self.register_buffer(name='weight_maskUpdater', persistent=False,
                             tensor=torch.ones(*updater_shape))

        # Number of mask entries summed per output position (everything but dim 0).
        self.slide_winsize = self.weight_maskUpdater.shape[1:].numel()

        # Cache: shape of the last input and the mask tensors derived for it.
        self.last_size = (None, None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_in=None):
        """Apply the convolution to ``input`` (multiplied by ``mask_in`` if given).

        Returns the re-weighted output, or ``(output, update_mask)`` when the layer
        was built with ``return_mask=True``.
        """
        assert len(input.shape) == 4
        current_size = tuple(input.shape)

        if mask_in is not None or self.last_size != current_size:
            self.last_size = current_size

            with torch.no_grad():
                if mask_in is not None:
                    mask = mask_in
                elif self.multi_channel:
                    # No mask supplied: treat every input element as valid.
                    mask = torch.ones_like(input)
                else:
                    mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3],
                                      device=input.device, dtype=input.dtype)

                window_sum = F.conv2d(mask, self.weight_maskUpdater, bias=None,
                                      stride=self.stride, padding=self.padding,
                                      dilation=self.dilation, groups=1)

                # for mixed precision training, change 1e-8 to 1e-6
                ratio = self.slide_winsize / (window_sum + 1e-8)
                self.update_mask = torch.clamp(window_sum, 0, 1)
                # Multiplying by the clamped mask zeroes the ratio where the window sum is zero.
                self.mask_ratio = torch.mul(ratio, self.update_mask)

        conv_input = input if mask_in is None else torch.mul(input, mask)
        raw_out = super().forward(conv_input)

        if self.bias is None:
            output = torch.mul(raw_out, self.mask_ratio)
        else:
            # Re-weight only the bias-free part, then restore the bias and mask the result.
            bias_view = self.bias.view(1, self.out_channels, 1, 1)
            reweighted = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            output = torch.mul(reweighted, self.update_mask)

        if self.return_mask:
            return output, self.update_mask
        return output


class MaskedConv2d(nn.Conv2d):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            eps=1e-8,
            multichannel: bool = False,
            partial_conv: bool = False,
            device=None,
            dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
        if multichannel:
            self.register_buffer('mask_weight', torch.ones(out_channels, self.in_channels // groups, *self.kernel_size, **factory_kwargs), persistent=False)
        else:
            self.register_buffer('mask_weight', torch.ones(1, 1, *self.kernel_size, **factory_kwargs), persistent=False)
        self.eps = eps
        self.multichannel = multichannel
        self.partial_conv = partial_conv

    def get_mask(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None
    ) -> (torch.Tensor, torch.Tensor):
        if mask is None:
            if self.multichannel:
                mask = torch.ones_like(input)
            else:
                mask = torch.ones(1, 1, *input.shape[2:], device=input.device, dtype=input.dtype)
        else:
            if self.multichannel:
                mask = mask.expand_as(input)
            else:
                mask = mask.expand(1, 1, *input.shape[2:])
        return mask

    def forward(
            self,
            input: torch.Tensor,
            mask: torch.Tensor | None = None
    ) -> (torch.Tensor, torch.Tensor | None):
        if mask is not None:
            input *= mask

        mask = self.get_mask(input, mask)

        if self.partial_conv:
            output = F.conv2d(input, self.weight, None, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            mask_kernel_numel = self.mask_weight.data.shape[1:].numel()
            mask_ratio = mask_kernel_numel / (mask + self.eps)
            mask.clamp_(0, 1)

            # Apply re-weighting and bias
            output *= mask_ratio
            if self.bias is not None:
                output += self.bias.view(-1, 1, 1)

            output *= mask
        else:
            output = F.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

            mask = F.conv2d(mask, self.mask_weight, None, self.stride, self.padding, self.dilation, self.groups if self.multichannel else 1)

            max_vals = mask.max(dim=2, keepdim=True)[0].max(dim=3, keepdim=True)[0]
            mask = mask / max_vals

        return output, mask

    def extra_repr(self):
        return f"{super().extra_repr()}, eps={self.eps}, multichannel={self.multichannel}, partial_conv={self.partial_conv}"


class MaskedPixelUnshuffle(nn.PixelUnshuffle):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        return super().forward(input), super().forward(mask) if mask is not None else None


class MaskedSequential(nn.Sequential):
    def forward(self, input: Tensor, mask: Tensor | None = None) -> (Tensor, Tensor | None):
        for module in self:
            input, mask = module(input, mask)
        return input, mask


@contextmanager
def register_hooks(
        model: torch.nn.Module,
        hook: Callable,
        predicate: Callable[[str, torch.nn.Module], bool],
        **hook_kwargs
):
    """Temporarily attach ``hook`` as a forward hook to selected submodules.

    Every ``(name, module)`` pair from ``model.named_modules()`` for which
    ``predicate(name, module)`` is true gets ``hook`` registered, with ``name`` and
    ``hook_kwargs`` pre-bound as keyword arguments. Yields the list of hook handles;
    all of them are removed when the ``with`` block exits, including on error.
    """
    handles = []
    try:
        for module_name, submodule in model.named_modules():
            if not predicate(module_name, submodule):
                continue
            # Bind the module's name so the hook can tell which module fired.
            bound_hook = partial(hook, name=module_name, **hook_kwargs)
            handles.append(submodule.register_forward_hook(bound_hook))
        yield handles
    finally:
        for registered in handles:
            registered.remove()


def activations_recorder_hook(
        module: torch.nn.Module,
        input: torch.Tensor,
        output: torch.Tensor,
        name: str,
        *,
        storage: dict[str, Any]
):
    """Forward hook that records ``output`` in ``storage`` under ``name``.

    The first output for a name is stored as-is. A second output turns the entry into
    a two-element list, and later outputs are appended to that list.
    """
    if name not in storage:
        storage[name] = output
        return

    previous = storage[name]
    if isinstance(previous, list):
        previous.append(output)
    else:
        storage[name] = [previous, output]


def forward_with_activations(
        model: torch.nn.Module,
        predicate: Callable[[str, torch.nn.Module], bool],
        *model_args,
        **model_kwargs,
) -> Tuple[torch.Tensor, dict[str, Any]]:
    """Run ``model`` once while recording the outputs of selected submodules.

    Submodules are selected by ``predicate(name, module)``. Returns the model's output
    together with a dict filled by ``activations_recorder_hook``, keyed by module name.
    """
    recorded = {}
    recording = register_hooks(model, activations_recorder_hook, predicate, storage=recorded)
    with recording:
        result = model(*model_args, **model_kwargs)
    return result, recorded


def test_it():
    """Compare the three masked-convolution variants on random input and plot activations.

    Builds three identical stacks of ``MaskedPixelUnshuffle`` + convolution layers
    (``PartialConv2d``, ``MaskedConv2d`` in partial mode, ``MaskedConv2d`` in
    normalised-mask mode), copies the first stack's weights into the other two, runs
    the same input through all of them, asserts that the two partial-conv
    implementations agree and the third differs, prints per-layer statistics and shows
    the channel-averaged activations of the first ``visualize_depth`` conv layers.
    """
    torch.manual_seed(37)

    in_channels = 3
    downscale_factor = 2
    scale = 1
    base = 2
    depth = 8
    visualize_depth = 4
    eps = 1e-8

    def build_stack(make_conv: Callable[..., nn.Module]) -> MaskedSequential:
        # One shared recipe for all three stacks, so they differ only in the conv class.
        layers = []
        for i in range(depth):
            layers.append(MaskedPixelUnshuffle(downscale_factor))
            layers.append(make_conv(
                in_channels=scale * base ** (i + 1) * downscale_factor ** 2 if i > 0 else in_channels * downscale_factor ** 2,
                out_channels=scale * base ** i * downscale_factor ** 2,
                kernel_size=(3, 3), padding=1, bias=False)
            )
        return MaskedSequential(*layers)

    # Construction order is kept (pconv, mpconv, mconv) so the seeded RNG is consumed
    # in the same sequence before `x` is drawn.
    pconv = build_stack(partial(PartialConv2d, multi_channel=True, return_mask=True))
    mpconv = build_stack(partial(MaskedConv2d, multichannel=True, partial_conv=True))
    mconv = build_stack(partial(MaskedConv2d, multichannel=True, partial_conv=False))

    with torch.no_grad():
        print(f"{pconv=}")
        print(f"{mpconv=}")
        print(f"{mconv=}")

        print(f"{list(pconv.state_dict().keys())=}")
        print(f"{list(mpconv.state_dict().keys())=}")
        print(f"{list(mconv.state_dict().keys())=}")

        # Give all three stacks the same weights so only the masking logic differs.
        mpconv.load_state_dict(pconv.state_dict())
        mconv.load_state_dict(pconv.state_dict())

        x = torch.randn(1, in_channels, downscale_factor**depth, downscale_factor**depth)
        mask_pconv, mask_mpconv, mask_mconv = torch.ones_like(x), torch.ones_like(x), torch.ones_like(x)

        def is_conv_predicate(name: str, module: torch.nn.Module):
            return isinstance(module, torch.nn.Conv2d)

        (y_pconv, mask_pconv), activations_pconv = forward_with_activations(pconv, is_conv_predicate, x, mask_pconv)
        (y_mpconv, mask_mpconv), activations_mpconv = forward_with_activations(mpconv, is_conv_predicate, x, mask_mpconv)
        (y_mconv, mask_mconv), activations_mconv = forward_with_activations(mconv, is_conv_predicate, x, mask_mconv)

        assert torch.allclose(y_mpconv, y_pconv)
        assert not torch.allclose(y_mconv, y_mpconv)

        print(f"{activations_pconv.keys()=}")  # ['1', '3', '5', '7', '9', '11', '13', '15']

        # One row per implementation, one column per visualised layer.
        fig, axs = plt.subplots(nrows=3, ncols=visualize_depth, figsize=(12, 8), dpi=180)
        axs = axs.flatten()

        for impl_i, (name, activations) in enumerate([
            ("pconv", activations_pconv),
            ("mpconv", activations_mpconv),
            ("mconv", activations_mconv)
        ]):
            batch_i = 0
            for depth_i in range(visualize_depth):
                ax = axs[impl_i * visualize_depth + depth_i]

                # Conv layers sit at the odd indices of each stack; every recorded entry
                # is the layer's (output, mask) pair, of which only the output is used.
                output = activations[f"{depth_i * 2 + 1}"][0][batch_i]

                mean = output.mean()
                std = output.std(unbiased=False)
                skewness = ((output - mean) ** 3).mean() / (std ** 3 + eps)
                kurtosis = ((output - mean) ** 4).mean() / (std ** 4 + eps)
                print(f"{name=}, {depth_i=}, {mean=}, {std=}, {skewness=}, {kurtosis=}")

                # Symmetric colour range of one standard deviation, passed as plain floats.
                color_limit = float(std)
                ax.imshow(output.mean(dim=0).numpy(), cmap='coolwarm', vmin=-color_limit, vmax=color_limit)
                ax.set_title(f"{name} {depth_i=}")
                ax.axis('off')

        plt.show()


# Run the comparison only when this file is executed as a script.
if __name__ == '__main__':
    test_it()

Output:

name='pconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='pconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='pconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='pconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mpconv', depth_i=0, mean=tensor(-0.0040), std=tensor(0.5844), skewness=tensor(0.0056), kurtosis=tensor(3.0593)
name='mpconv', depth_i=1, mean=tensor(-0.0014), std=tensor(0.3347), skewness=tensor(-0.0053), kurtosis=tensor(3.1046)
name='mpconv', depth_i=2, mean=tensor(-0.0001), std=tensor(0.1993), skewness=tensor(0.0125), kurtosis=tensor(3.2002)
name='mpconv', depth_i=3, mean=tensor(-0.0013), std=tensor(0.1211), skewness=tensor(-0.0061), kurtosis=tensor(3.5512)
name='mconv', depth_i=0, mean=tensor(-0.0039), std=tensor(0.5769), skewness=tensor(0.0052), kurtosis=tensor(3.0468)
name='mconv', depth_i=1, mean=tensor(-0.0016), std=tensor(0.3209), skewness=tensor(-0.0099), kurtosis=tensor(3.0444)
name='mconv', depth_i=2, mean=tensor(-0.0003), std=tensor(0.1796), skewness=tensor(-0.0102), kurtosis=tensor(3.1047)
name='mconv', depth_i=3, mean=tensor(-0.0011), std=tensor(0.0973), skewness=tensor(-0.0421), kurtosis=tensor(3.3349)

image

  • pconv is the original implementation of partial conv (this repo)
  • mpconv is my implementation of partial conv
  • mconv is my approach of masked convolution

Here are also the activations on real images:


  • PartialConv2d:
    image

  • MaskedConv2d:
    image

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions