Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .jules/thunderbolt.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
## 2024-05-20 - AVX2 Vectorized Softmax with single-FMA range reduction and 8x max unroll

**Learning:** Replaces the 2-FMA Cody-Waite range reduction in `exp256` with a single FMA using `ln(2)`, removing an instruction from the critical path while remaining within ML precision tolerances. Additionally, unrolling the max reduction 8x (from 4x) to better saturate execution ports yields measurable throughput improvements over `softmax_v5` implementation on larger inputs and fixed memory configurations (e.g. N=1048576, GFLOP/s improved from 3.56 to 3.78).

**Evidence:** End-to-end framework benchmarks showed an increase in GFLOP/s for N=1048576 (Fixed Memory) from 3.56 to 3.78 and for N=262144 (Fixed Memory) from 4.00 to 4.18.
Comment on lines +1 to +5
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Fix benchmark metadata consistency (date + unit/value).

Line 1 uses 2024-05-20, but this PR was created on 2026-05-20. Also Line 3/Line 5 cites 3.56 -> 3.78 as GFLOP/s, while the PR objective reports 3.57 -> 3.78 for GB/s at N=1,048,576. Please align the log entry to the measured run metadata to avoid ambiguity.

✏️ Suggested doc patch
-## 2024-05-20 - AVX2 Vectorized Softmax with single-FMA range reduction and 8x max unroll
+## 2026-05-20 - AVX2 Vectorized Softmax with single-FMA range reduction and 8x max unroll
@@
-**Learning:** Replaces the 2-FMA Cody-Waite range reduction in `exp256` with a single FMA using `ln(2)`, removing an instruction from the critical path while remaining within ML precision tolerances. Additionally, unrolling the max reduction 8x (from 4x) to better saturate execution ports yields measurable throughput improvements over `softmax_v5` implementation on larger inputs and fixed memory configurations (e.g. N=1048576, GFLOP/s improved from 3.56 to 3.78).
+**Learning:** Replaces the 2-FMA Cody-Waite range reduction in `exp256` with a single FMA using `ln(2)`, removing an instruction from the critical path while remaining within ML precision tolerances. Additionally, unrolling the max reduction 8x (from 4x) to better saturate execution ports yields measurable throughput improvements over `softmax_v5` on larger fixed-memory inputs (e.g. N=1048576, throughput improved from 3.57 to 3.78 GB/s).
@@
-**Evidence:** End-to-end framework benchmarks showed an increase in GFLOP/s for N=1048576 (Fixed Memory) from 3.56 to 3.78 and for N=262144 (Fixed Memory) from 4.00 to 4.18.
+**Evidence:** End-to-end framework benchmarks showed throughput at N=1048576 (Fixed Memory) improving from 3.57 to 3.78 GB/s, and GFLOP/s at N=262144 (Fixed Memory) improving from 4.00 to 4.18.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In @.jules/thunderbolt.md around lines 1 - 5, Update the log entry date to
2026-05-20 and correct the benchmark metadata: in the Evidence line that
mentions N=1048576 (Fixed Memory) replace the unit and value "GFLOP/s improved
from 3.56 to 3.78" with "GB/s improved from 3.57 to 3.78" so it matches the PR
objective; keep references to the implementation names exp256 and softmax_v5
as-is to preserve context.


**Action:** In transcendental AVX2 SIMD approximations, combining constants for `r = x - n * ln(2)` into a single FMA instruction—rather than splitting `ln(2)` for exact precision—can significantly boost throughput while keeping results within typical ML numerical tolerances due to the shift-invariant nature of operations like softmax.

## 2024-10-24 - AVX2 Vectorized Softmax Implementation

**Learning:** When vectorizing transcendental functions like `exp` in AVX2, standard Horner's method (`p = _mm256_fmadd_ps(p, r, c)`) creates a strict dependency chain bounded by the 4-cycle FMA latency. Estrin's scheme can break this chain and yield higher ILP. Additionally, standard library headers like `<algorithm>` for `std::max` should always be explicitly included even when not strictly required by the current benchmark/compiler, to avoid cross-platform compilation errors.
Expand Down
154 changes: 154 additions & 0 deletions ml_kernels/include/ml_kernels/softmax.h
Original file line number Diff line number Diff line change
Expand Up @@ -501,4 +501,158 @@ inline void softmax_v5(const float *input, float *output, std::size_t n) {
}
}


inline __m256 exp256_ps_v3(__m256 x) {
x = _mm256_max_ps(x, _mm256_set1_ps(-87.3f));
__m256 x_log2e = _mm256_mul_ps(x, _mm256_set1_ps(1.4426950408889634f));

// cvtps_epi32 defaults to round-to-nearest in AVX2, avoiding round_ps
__m256i n_int = _mm256_cvtps_epi32(x_log2e);
__m256 n = _mm256_cvtepi32_ps(n_int);

// Use a single FMA for range reduction instead of splitting ln(2)
// ln(2) = 0.6931471805599453f
__m256 r = _mm256_fnmadd_ps(n, _mm256_set1_ps(0.6931471805599453f), x);

// Horner's scheme instead of Estrin
__m256 c1 = _mm256_set1_ps(1.0f);
__m256 c2 = _mm256_set1_ps(1.0f / 2.0f);
__m256 c3 = _mm256_set1_ps(1.0f / 6.0f);
__m256 c4 = _mm256_set1_ps(1.0f / 24.0f);
__m256 c5 = _mm256_set1_ps(1.0f / 120.0f);

__m256 p = _mm256_fmadd_ps(c5, r, c4);
p = _mm256_fmadd_ps(p, r, c3);
p = _mm256_fmadd_ps(p, r, c2);
p = _mm256_fmadd_ps(p, r, c1);
p = _mm256_fmadd_ps(p, r, c1);

__m256i exp_shift = _mm256_add_epi32(n_int, _mm256_set1_epi32(127));
__m256i exp_shifted = _mm256_slli_epi32(exp_shift, 23);
__m256 exp2n = _mm256_castsi256_ps(exp_shifted);

return _mm256_mul_ps(p, exp2n);
}

// ⚡ Thunderbolt: AVX2 Vectorized Softmax with single-FMA range reduction and 8x max unroll
// Target: AVX2 (Haswell+)
// Reason: Replaces the 2-FMA Cody-Waite range reduction in `exp256` with a single FMA using `ln(2)`,
// removing an instruction from the critical path while remaining within ML precision tolerances.
// Additionally, unrolls the max reduction 8x (from 4x) to better saturate execution ports.
// Expected gain: Measurable throughput improvement over softmax_v5.
inline void softmax_v6(const float *input, float *output, std::size_t n) {
if (n == 0) return;

// 1. Find max (8x unrolled)
std::size_t i = 0;
__m256 max_v = _mm256_set1_ps(std::numeric_limits<float>::lowest());
__m256 max0 = max_v, max1 = max_v, max2 = max_v, max3 = max_v;
__m256 max4 = max_v, max5 = max_v, max6 = max_v, max7 = max_v;

for (; i + 63 < n; i += 64) {
max0 = _mm256_max_ps(max0, _mm256_loadu_ps(input + i));
max1 = _mm256_max_ps(max1, _mm256_loadu_ps(input + i + 8));
max2 = _mm256_max_ps(max2, _mm256_loadu_ps(input + i + 16));
max3 = _mm256_max_ps(max3, _mm256_loadu_ps(input + i + 24));
max4 = _mm256_max_ps(max4, _mm256_loadu_ps(input + i + 32));
max5 = _mm256_max_ps(max5, _mm256_loadu_ps(input + i + 40));
max6 = _mm256_max_ps(max6, _mm256_loadu_ps(input + i + 48));
max7 = _mm256_max_ps(max7, _mm256_loadu_ps(input + i + 56));
}
max0 = _mm256_max_ps(max0, max4);
max1 = _mm256_max_ps(max1, max5);
max2 = _mm256_max_ps(max2, max6);
max3 = _mm256_max_ps(max3, max7);
max0 = _mm256_max_ps(max0, max1);
max2 = _mm256_max_ps(max2, max3);
max0 = _mm256_max_ps(max0, max2);
for (; i + 7 < n; i += 8) {
max0 = _mm256_max_ps(max0, _mm256_loadu_ps(input + i));
}
float max_val = reduce_max(max0);
for (; i < n; ++i) max_val = std::max(max_val, input[i]);

__m256 max_vec = _mm256_set1_ps(max_val);

// 2. Compute exp and sum (4x unrolled to avoid register spill and balance latency)
i = 0;
__m256 sum0 = _mm256_setzero_ps();
__m256 sum1 = _mm256_setzero_ps();
__m256 sum2 = _mm256_setzero_ps();
__m256 sum3 = _mm256_setzero_ps();

for (; i + 31 < n; i += 32) {
__m256 x0 = _mm256_sub_ps(_mm256_loadu_ps(input + i), max_vec);
__m256 x1 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 8), max_vec);
__m256 x2 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 16), max_vec);
__m256 x3 = _mm256_sub_ps(_mm256_loadu_ps(input + i + 24), max_vec);

__m256 e0 = exp256_ps_v3(x0);
__m256 e1 = exp256_ps_v3(x1);
__m256 e2 = exp256_ps_v3(x2);
__m256 e3 = exp256_ps_v3(x3);

_mm256_storeu_ps(output + i, e0);
_mm256_storeu_ps(output + i + 8, e1);
_mm256_storeu_ps(output + i + 16, e2);
_mm256_storeu_ps(output + i + 24, e3);

sum0 = _mm256_add_ps(sum0, e0);
sum1 = _mm256_add_ps(sum1, e1);
sum2 = _mm256_add_ps(sum2, e2);
sum3 = _mm256_add_ps(sum3, e3);
}
sum0 = _mm256_add_ps(sum0, sum1);
sum2 = _mm256_add_ps(sum2, sum3);
sum0 = _mm256_add_ps(sum0, sum2);

for (; i + 7 < n; i += 8) {
__m256 x = _mm256_loadu_ps(input + i);
__m256 e = exp256_ps_v3(_mm256_sub_ps(x, max_vec));
_mm256_storeu_ps(output + i, e);
sum0 = _mm256_add_ps(sum0, e);
}

float sum_val = reduce_sum(sum0);
for (; i < n; ++i) {
float e = std::exp(input[i] - max_val);
output[i] = e;
sum_val += e;
}

if (sum_val == 0.0f) return;

// 3. Normalize
float inv_sum = 1.0f / sum_val;
__m256 inv_sum_v = _mm256_set1_ps(inv_sum);
i = 0;

// Unrolling normalize 8x to saturate execution ports better
for (; i + 63 < n; i += 64) {
__m256 o0 = _mm256_loadu_ps(output + i);
__m256 o1 = _mm256_loadu_ps(output + i + 8);
__m256 o2 = _mm256_loadu_ps(output + i + 16);
__m256 o3 = _mm256_loadu_ps(output + i + 24);
__m256 o4 = _mm256_loadu_ps(output + i + 32);
__m256 o5 = _mm256_loadu_ps(output + i + 40);
__m256 o6 = _mm256_loadu_ps(output + i + 48);
__m256 o7 = _mm256_loadu_ps(output + i + 56);

_mm256_storeu_ps(output + i, _mm256_mul_ps(o0, inv_sum_v));
_mm256_storeu_ps(output + i + 8, _mm256_mul_ps(o1, inv_sum_v));
_mm256_storeu_ps(output + i + 16, _mm256_mul_ps(o2, inv_sum_v));
_mm256_storeu_ps(output + i + 24, _mm256_mul_ps(o3, inv_sum_v));
_mm256_storeu_ps(output + i + 32, _mm256_mul_ps(o4, inv_sum_v));
_mm256_storeu_ps(output + i + 40, _mm256_mul_ps(o5, inv_sum_v));
_mm256_storeu_ps(output + i + 48, _mm256_mul_ps(o6, inv_sum_v));
_mm256_storeu_ps(output + i + 56, _mm256_mul_ps(o7, inv_sum_v));
}
for (; i + 7 < n; i += 8) {
_mm256_storeu_ps(output + i, _mm256_mul_ps(_mm256_loadu_ps(output + i), inv_sum_v));
}
for (; i < n; ++i) {
output[i] *= inv_sum;
}
}

} // namespace ml_kernels
11 changes: 11 additions & 0 deletions ml_kernels/src/kernel_bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,17 @@ class SoftmaxV5Benchmark : public SoftmaxBenchmark {
};
REGISTER_BENCHMARK(SoftmaxV5Benchmark);

class SoftmaxV6Benchmark : public SoftmaxBenchmark {
public:
const char *name() const override { return "softmax_v6"; }

void run() override {
ml_kernels::softmax_v6(inputs_[current_idx_].data(), outputs_[current_idx_].data(), inputs_[0].size());
current_idx_ = (current_idx_ + 1) % pool_size_;
}
};
REGISTER_BENCHMARK(SoftmaxV6Benchmark);

} // namespace

int main(int argc, char **argv) {
Expand Down
30 changes: 30 additions & 0 deletions ml_kernels/src/test_naive_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,41 @@ void test_softmax_v5() {
std::cout << "test_softmax_v5 passed!" << std::endl;
}

void test_softmax_v6() {
std::cout << "Running test_softmax_v6..." << std::endl;
std::vector<float> input = {
-2.0f, -0.5f, 1.0f, 3.0f,
0.0f, 0.0f, 0.0f, 0.0f,
100.0f, 100.0f, -100.0f, -100.0f,
5.0f, -5.0f, 2.0f, -2.0f,
1.1f, 1.2f, 1.3f, 1.4f,
-1.1f, -1.2f, -1.3f, -1.4f,
10.0f, 20.0f, 30.0f, 40.0f,
-10.0f, -20.0f, -30.0f, -40.0f
};

std::vector<float> output_naive(input.size(), 0.0f);
std::vector<float> output_v6(input.size(), 0.0f);

ml_kernels::softmax_naive(input.data(), output_naive.data(), input.size());
ml_kernels::softmax_v6(input.data(), output_v6.data(), input.size());

float sum = 0.0f;
for (std::size_t i = 0; i < input.size(); ++i) {
assert(std::fabs(output_naive[i] - output_v6[i]) < 1e-4f);
sum += output_v6[i];
}
assert(std::fabs(sum - 1.0f) < 1e-4f);

std::cout << "test_softmax_v6 passed!" << std::endl;
}

int main() {
test_relu_naive();
test_max_naive();
test_softmax_v3();
test_softmax_v4();
test_softmax_v5();
test_softmax_v6();
std::cout << "All tests passed successfully!" << std::endl;
}
Loading