Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions NAM/conv1d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,104 @@

namespace nam
{
namespace
{
// Templated per-tap accumulating kernel for Conv1D.
// OutCh, InCh, Groups are compile-time constants so the compiler unrolls every loop
// and folds all index arithmetic. Off-block-diagonal zeros are never visited.
// Weight memory layout is col-major (out_channels rows x in_channels cols), matching
// Eigen::MatrixXf default storage in nam::Conv1D::_weight[k].
// Input layout is assumed contiguous (channels rows x num_frames cols, col-major), as
// the existing inline-GEMM cascade also assumes.
template <int OutCh, int InCh, int Groups>
void templated_conv1d_tap_kernel(const float* __restrict__ weight, const float* __restrict__ in,
float* __restrict__ out, int num_frames)
{
static_assert(OutCh % Groups == 0, "OutCh must be divisible by Groups");
static_assert(InCh % Groups == 0, "InCh must be divisible by Groups");
constexpr int OutPerG = OutCh / Groups;
constexpr int InPerG = InCh / Groups;
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = in + f * InCh;
float* __restrict__ out_col = out + f * OutCh;
for (int g = 0; g < Groups; g++)
{
const int o_base = g * OutPerG;
const int i_base = g * InPerG;
for (int o = 0; o < OutPerG; o++)
{
float sum = 0.0f;
for (int i = 0; i < InPerG; i++)
{
sum += weight[(i_base + i) * OutCh + (o_base + o)] * in_col[i_base + i];
}
out_col[o_base + o] += sum;
}
}
}
}

// Map (out_channels, in_channels, groups) -> templated tap-kernel function pointer.
// Returns nullptr for unregistered shapes; caller falls back to existing inline /
// Eigen GEMM cascade. Depthwise (groups == channels) is handled by Conv1D's existing
// _is_depthwise path and is intentionally not registered here.
nam::Conv1D::TapKernel pick_conv1d_tap_kernel(int out_channels, int in_channels, int groups)
{
using K = nam::Conv1D::TapKernel;
if (out_channels == 4 && in_channels == 4)
{
if (groups == 1)
return static_cast<K>(&templated_conv1d_tap_kernel<4, 4, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1d_tap_kernel<4, 4, 2>);
}
if (out_channels == 6 && in_channels == 6)
{
if (groups == 1)
return static_cast<K>(&templated_conv1d_tap_kernel<6, 6, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1d_tap_kernel<6, 6, 2>);
if (groups == 3)
return static_cast<K>(&templated_conv1d_tap_kernel<6, 6, 3>);
}
if (out_channels == 8 && in_channels == 8)
{
if (groups == 1)
return static_cast<K>(&templated_conv1d_tap_kernel<8, 8, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1d_tap_kernel<8, 8, 2>);
if (groups == 4)
return static_cast<K>(&templated_conv1d_tap_kernel<8, 8, 4>);
}
if (out_channels == 12 && in_channels == 12)
{
if (groups == 1)
return static_cast<K>(&templated_conv1d_tap_kernel<12, 12, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1d_tap_kernel<12, 12, 2>);
if (groups == 3)
return static_cast<K>(&templated_conv1d_tap_kernel<12, 12, 3>);
if (groups == 4)
return static_cast<K>(&templated_conv1d_tap_kernel<12, 12, 4>);
if (groups == 6)
return static_cast<K>(&templated_conv1d_tap_kernel<12, 12, 6>);
}
if (out_channels == 16 && in_channels == 16)
{
if (groups == 1)
return static_cast<K>(&templated_conv1d_tap_kernel<16, 16, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1d_tap_kernel<16, 16, 2>);
if (groups == 4)
return static_cast<K>(&templated_conv1d_tap_kernel<16, 16, 4>);
if (groups == 8)
return static_cast<K>(&templated_conv1d_tap_kernel<16, 16, 8>);
}
return nullptr;
}
} // namespace

// Conv1D =====================================================================

void Conv1D::set_weights_(std::vector<float>::iterator& weights)
Expand Down Expand Up @@ -86,6 +184,7 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int
this->_depthwise_weight[i].setZero();
}
this->_weight.clear(); // Not used for depthwise
this->_tap_kernel = nullptr;
}
else
{
Expand All @@ -99,6 +198,10 @@ void Conv1D::set_size_(const int in_channels, const int out_channels, const int
}
this->_depthwise_weight.clear(); // Not used for non-depthwise
this->_channels = 0;
// Look up a shape-specialized templated per-tap kernel. Skips zeros for grouped
// cases and bypasses Eigen GEMM for small dense cases. nullptr -> fall back to
// existing inline / Eigen GEMM cascade.
this->_tap_kernel = pick_conv1d_tap_kernel(out_channels, in_channels, groups);
}

if (do_bias)
Expand Down Expand Up @@ -251,6 +354,21 @@ void Conv1D::Process(const Eigen::MatrixXf& input, const int num_frames)
}
#endif
}
else if (this->_tap_kernel != nullptr)
{
// Shape-specialized templated per-tap kernel (constexpr-unrolled, skips off-diagonal
// zeros for grouped cases). Accumulates across taps so output must be zeroed first.
_output.leftCols(num_frames).setZero();
const size_t kernel_size = this->_weight.size();
float* __restrict__ output_ptr = _output.data();
for (size_t k = 0; k < kernel_size; k++)
{
const long offset = this->_dilation * (k + 1 - (long)kernel_size);
const long lookback = -offset;
auto input_block = _input_buffer.Read(num_frames, lookback);
this->_tap_kernel(this->_weight[k].data(), input_block.data(), output_ptr, num_frames);
}
}
else
{
#ifdef NAM_USE_INLINE_GEMM
Expand Down Expand Up @@ -736,6 +854,20 @@ void Conv1D::process_(const Eigen::MatrixXf& input, Eigen::MatrixXf& output, con
this->_depthwise_weight[k].asDiagonal() * input.middleCols(i_start + offset, ncols);
}
}
else if (this->_tap_kernel != nullptr && input.outerStride() == input.rows() && output.outerStride() == output.rows())
{
// Shape-specialized templated per-tap kernel; accumulates so zero the output slice first.
// Guarded by the stride check because the kernel assumes contiguous column-major storage.
output.middleCols(j_start, ncols).setZero();
float* __restrict__ out_ptr = output.data() + j_start * output.rows();
const size_t kernel_size = this->_weight.size();
for (size_t k = 0; k < kernel_size; k++)
{
const long offset = this->_dilation * (k + 1 - (long)kernel_size);
const float* __restrict__ in_ptr = input.data() + (i_start + offset) * input.rows();
this->_tap_kernel(this->_weight[k].data(), in_ptr, out_ptr, (int)ncols);
}
}
else
{
// Grouped convolution note: The weight matrices are block-diagonal (zeros off-diagonal),
Expand Down
11 changes: 11 additions & 0 deletions NAM/conv1d.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ class Conv1D
/// \return true if bias is present, false otherwise
bool has_bias() const { return this->_bias.size() > 0; };

// Function pointer to a shape-specialized per-tap GEMM kernel that accumulates into
// the output buffer (out += weight * in). Public so the dispatch table in conv1d.cpp
// can return values of this type without exposing internals.
using TapKernel = void (*)(const float* weight, const float* in, float* out, int num_frames);

protected:
// conv[kernel](cout, cin) - used for non-depthwise convolutions
std::vector<Eigen::MatrixXf> _weight;
Expand All @@ -129,6 +134,12 @@ class Conv1D
int _dilation;
int _num_groups;

// Set at construction time when (in_channels, out_channels, groups) matches a
// registered template specialization. When non-null, the non-depthwise Process /
// process_ paths invoke this per tap instead of running a dense Eigen / inline GEMM
// through the block-diagonal zero structure. nullptr -> fall back to generic.
TapKernel _tap_kernel = nullptr;

private:
RingBuffer _input_buffer; // Ring buffer for input (channels x buffer_size)
Eigen::MatrixXf _output; // Pre-allocated output buffer (out_channels x maxBufferSize)
Expand Down
120 changes: 118 additions & 2 deletions NAM/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,106 @@ static nam::ConfigParserHelper _register_Linear("Linear", nam::linear::create_co

// Conv1x1 ====================================================================

namespace
{
// Templated dense/grouped 1x1 kernel.
// OutCh, InCh, Groups are compile-time constants so the compiler unrolls every loop
// and folds all index arithmetic. Off-block-diagonal zeros are never visited.
// Weight memory layout is col-major (out_channels rows x in_channels cols) -
// matching Eigen::MatrixXf default storage in nam::Conv1x1::_weight.
template <int OutCh, int InCh, int Groups>
void templated_conv1x1_kernel(const float* __restrict__ weight, const float* __restrict__ in, float* __restrict__ out,
int num_frames, int in_stride)
{
static_assert(OutCh % Groups == 0, "OutCh must be divisible by Groups");
static_assert(InCh % Groups == 0, "InCh must be divisible by Groups");
constexpr int OutPerG = OutCh / Groups;
constexpr int InPerG = InCh / Groups;
for (int f = 0; f < num_frames; f++)
{
const float* __restrict__ in_col = in + f * in_stride;
float* __restrict__ out_col = out + f * OutCh;
for (int g = 0; g < Groups; g++)
{
constexpr int row_offset_per_group = OutPerG;
constexpr int col_offset_per_group = InPerG;
const int o_base = g * row_offset_per_group;
const int i_base = g * col_offset_per_group;
for (int o = 0; o < OutPerG; o++)
{
float sum = 0.0f;
for (int i = 0; i < InPerG; i++)
{
sum += weight[(i_base + i) * OutCh + (o_base + o)] * in_col[i_base + i];
}
out_col[o_base + o] = sum;
}
}
}
}

// Map (out_channels, in_channels, groups) -> templated kernel function pointer.
// Returns nullptr when no specialization is registered; caller falls back to the
// generic Eigen / inline-GEMM path.
nam::Conv1x1::ProcessKernel pick_conv1x1_kernel(int out_channels, int in_channels, int groups)
{
using K = nam::Conv1x1::ProcessKernel;
// Square shapes (the layer1x1 / head1x1 / FiLM cases that dominate WaveNet).
// Depthwise (groups == channels) is handled by the dedicated _is_depthwise path
// and is intentionally not registered here.
if (out_channels == 4 && in_channels == 4)
{
if (groups == 1)
return static_cast<K>(&templated_conv1x1_kernel<4, 4, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1x1_kernel<4, 4, 2>);
}
if (out_channels == 6 && in_channels == 6)
{
if (groups == 1)
return static_cast<K>(&templated_conv1x1_kernel<6, 6, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1x1_kernel<6, 6, 2>);
if (groups == 3)
return static_cast<K>(&templated_conv1x1_kernel<6, 6, 3>);
}
if (out_channels == 8 && in_channels == 8)
{
if (groups == 1)
return static_cast<K>(&templated_conv1x1_kernel<8, 8, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1x1_kernel<8, 8, 2>);
if (groups == 4)
return static_cast<K>(&templated_conv1x1_kernel<8, 8, 4>);
}
if (out_channels == 12 && in_channels == 12)
{
if (groups == 1)
return static_cast<K>(&templated_conv1x1_kernel<12, 12, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1x1_kernel<12, 12, 2>);
if (groups == 3)
return static_cast<K>(&templated_conv1x1_kernel<12, 12, 3>);
if (groups == 4)
return static_cast<K>(&templated_conv1x1_kernel<12, 12, 4>);
if (groups == 6)
return static_cast<K>(&templated_conv1x1_kernel<12, 12, 6>);
}
if (out_channels == 16 && in_channels == 16)
{
if (groups == 1)
return static_cast<K>(&templated_conv1x1_kernel<16, 16, 1>);
if (groups == 2)
return static_cast<K>(&templated_conv1x1_kernel<16, 16, 2>);
if (groups == 4)
return static_cast<K>(&templated_conv1x1_kernel<16, 16, 4>);
if (groups == 8)
return static_cast<K>(&templated_conv1x1_kernel<16, 16, 8>);
}
return nullptr;
}
} // namespace

nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool _bias, const int groups)
{
// Validate that channels divide evenly by groups
Expand Down Expand Up @@ -376,6 +476,9 @@ nam::Conv1x1::Conv1x1(const int in_channels, const int out_channels, const bool
this->_weight.resize(out_channels, in_channels);
this->_weight.setZero();
this->_channels = 0;
// Look up a shape-specialized templated kernel. Skips zeros for grouped cases and
// bypasses Eigen GEMM for small dense cases. nullptr -> fall back to generic kernel.
this->_kernel = pick_conv1x1_kernel(out_channels, in_channels, groups);
}

if (_bias)
Expand Down Expand Up @@ -452,9 +555,14 @@ Eigen::MatrixXf nam::Conv1x1::process(const Eigen::MatrixXf& input, const int nu
// Each channel is scaled by its corresponding weight
result.noalias() = this->_depthwise_weight.asDiagonal() * input.leftCols(num_frames);
}
else if (this->_kernel != nullptr)
{
// Shape-specialized templated kernel (constexpr-unrolled, skips off-diagonal zeros).
this->_kernel(this->_weight.data(), input.data(), result.data(), num_frames, (int)input.outerStride());
}
else
{
// Single GEMM for all cases - block-diagonal zero structure handles grouping
// Generic fallback: single dense GEMM through the block-diagonal zero structure.
result.noalias() = this->_weight * input.leftCols(num_frames);
}

Expand All @@ -477,6 +585,12 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
// Each channel is scaled by its corresponding weight
_output.leftCols(num_frames).noalias() = this->_depthwise_weight.asDiagonal() * input.leftCols(num_frames);
}
else if (this->_kernel != nullptr)
{
// Shape-specialized templated kernel (constexpr-unrolled, skips off-diagonal zeros
// for grouped cases). Bias is applied after this block by the shared bias path.
this->_kernel(this->_weight.data(), input.data(), _output.data(), num_frames, (int)input.outerStride());
}
else
{
#ifdef NAM_USE_INLINE_GEMM
Expand Down Expand Up @@ -745,7 +859,9 @@ void nam::Conv1x1::process_(const Eigen::Ref<const Eigen::MatrixXf>& input, cons
}
}
#else
// Single GEMM for all cases - block-diagonal zero structure handles grouping
// Single GEMM for all cases - block-diagonal zero structure handles grouping.
// Per-group Eigen blocks were tried but small-block GEMM overhead dominates;
// see the inline-GEMM path above for grouped-specific kernels.
_output.leftCols(num_frames).noalias() = this->_weight * input.leftCols(num_frames);
#endif
}
Expand Down
10 changes: 10 additions & 0 deletions NAM/dsp.h
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,10 @@ class Conv1x1
long get_out_channels() const;
long get_in_channels() const;

// Function pointer to a shape-specialized GEMM kernel. Public so the dispatch table
// in dsp.cpp can return values of this type without exposing internals.
using ProcessKernel = void (*)(const float* weight, const float* in, float* out, int num_frames, int in_stride);

protected:
// Non-depthwise: full weight matrix (out_channels x in_channels)
Eigen::MatrixXf _weight;
Expand All @@ -363,6 +367,12 @@ class Conv1x1
Eigen::VectorXf _bias;
int _num_groups;

// Set at construction time when (in_channels, out_channels, groups) matches a
// registered template specialization. When non-null, used by both the Eigen and
// inline-GEMM process_ paths in preference to the generic dense kernel.
// nullptr -> fall back to generic.
ProcessKernel _kernel = nullptr;

private:
Eigen::MatrixXf _output;
bool _do_bias;
Expand Down
Loading