Skip to content
This repository was archived by the owner on Mar 1, 2025. It is now read-only.

Rewrite for convolution operation #241

@CheungBH

Description

@CheungBH

Thanks for your great work.
I want to run inference with the sparse convolution operation, but the code doesn't provide an inference-only function, so I rewrote it as shown below.
When `only_forward=True`, the input is processed directly, without saving anything on `ctx`.
However, this approach produces NaNs in the output. Do you have any ideas on how to fix it?

class ConvolutionFunction(Function):
    """Autograd function for a sparse convolution computed by ``sparseconvnet.SCN``.

    Compared with the stock version, ``forward`` accepts an extra
    ``only_forward`` flag. When it is True, nothing is stashed on ``ctx`` for
    the backward pass and the output is marked non-differentiable. A stray
    ``.backward()`` therefore does not try to unpack empty
    ``ctx.saved_tensors``.
    """

    @staticmethod
    def forward(
            ctx,
            input_features,
            weight,
            bias,
            input_metadata,
            input_spatial_size,
            output_spatial_size,
            dimension,
            filter_size,
            filter_stride,
            only_forward=False):
        """Run ``SCN.Convolution_updateOutput`` and return the output features.

        Args:
            ctx: autograd context; may be ``None`` when ``only_forward`` is
                True and this method is called directly rather than via
                ``apply``.
            input_features, weight, bias: tensors forwarded to the SCN kernel
                (layout is defined by SCN -- not visible here).
            input_metadata: SCN metadata object describing the sparse sites.
            input_spatial_size, output_spatial_size, filter_size,
            filter_stride: size/stride tensors forwarded to the SCN kernel
                (they are also passed to ``save_for_backward``, so they must
                be tensors).
            dimension: spatial dimension; stored on ``ctx`` only.
            only_forward: if True, skip all bookkeeping needed for
                ``backward`` and the global forward-pass counters.

        Returns:
            A new tensor filled in place by the SCN kernel.

        Both branches make the identical ``Convolution_updateOutput`` call
        with identical arguments. NOTE(review): if NaNs appear only with
        ``only_forward=True``, compare the inputs each caller passes; skipping
        ``ctx`` cannot by itself change the kernel's output. For plain
        inference, ``only_forward=False`` under ``torch.no_grad()`` is an
        alternative worth comparing against.
        """
        # Empty tensor of the input's type; the SCN kernel fills it in place.
        output_features = input_features.new()

        if not only_forward:
            # State consumed by backward().
            ctx.input_metadata = input_metadata
            ctx.dimension = dimension
            ctx.save_for_backward(
                input_features,
                input_spatial_size,
                weight,
                bias,
                output_spatial_size,
                filter_size,
                filter_stride)

        multiply_add_count = sparseconvnet.SCN.Convolution_updateOutput(
            input_spatial_size,
            output_spatial_size,
            filter_size,
            filter_stride,
            input_metadata,
            input_features,
            output_features,
            weight,
            bias)

        if only_forward:
            # Nothing was saved for backward(); make that explicit to autograd
            # so gradient flow stops here instead of failing in backward().
            if ctx is not None:
                ctx.mark_non_differentiable(output_features)
            return output_features

        sparseconvnet.forward_pass_multiplyAdd_count += multiply_add_count
        sparseconvnet.forward_pass_hidden_states += output_features.nelement()
        return output_features

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. input features, weight and bias.

        Returns one entry per ``forward`` input (excluding ``ctx``): the three
        tensor gradients followed by ``None`` for the seven non-differentiable
        arguments (input_metadata, input_spatial_size, output_spatial_size,
        dimension, filter_size, filter_stride, only_forward).
        """
        (input_features, input_spatial_size, weight, bias,
         output_spatial_size, filter_size, filter_stride) = ctx.saved_tensors
        # Empty tensor; filled in place by the SCN backward kernel.
        grad_input = grad_output.new()
        # The SCN kernel writes into these zero-initialised buffers.
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(bias)
        sparseconvnet.SCN.Convolution_backward(
            input_spatial_size,
            output_spatial_size,
            filter_size,
            filter_stride,
            ctx.input_metadata,
            input_features,
            grad_input,
            grad_output.contiguous(),
            weight,
            grad_weight,
            grad_bias)
        # The count must match forward()'s 10 inputs; the previous 9-tuple
        # predates the only_forward argument and is rejected by autograd.
        return (grad_input, grad_weight, optionalTensorReturn(grad_bias),
                None, None, None, None, None, None, None)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions