32 changes: 32 additions & 0 deletions .github/copilot-instructions.md
@@ -94,6 +94,38 @@ Use inline methods where possible to avoid dispatch overhead.
## GitHub Actions CI
The project uses GitHub Actions for CI/CD.

## Gotchas

### JMH `Unit`-returning benchmarks and the Vector API scalar-replacement cliff

**Symptom**: A JMH benchmark that calls SIMD code (Vector API) shows a sudden, dramatic throughput cliff (100–500×) at a specific array size threshold, accompanied by massively inflated GC allocation (`gc.alloc.rate.norm` jumping from the expected one-array-worth to 30–70× that, e.g. 65 KB/op becomes 2.4 MB/op). The warmup iterations look healthy; the compiled measurement iterations are catastrophically slow. The compiled code is *slower* than the interpreter.

**Root cause**: C2 fails to scalar-replace `FloatVector`/`VectorMask` objects when the enclosing JMH benchmark method has a `Unit` return type (`def bench(bh: Blackhole): Unit`). The JIT sees the `bh` reference pre-loaded at the bottom of the operand stack and its escape analysis incorrectly concludes the transient Vector objects escape. They are heap-allocated every iteration, causing cascading GC pressure.

**The library code is not broken.** The same SIMD path invoked from a non-void return method measures correctly (expected alloc, linear scaling). Do not spend time changing `logicalFloatIdx`, `spf` declarations, loop structure, or any library code in response to this symptom — it is a benchmark artefact.

**Fix**: Change the benchmark method to return its result explicitly:
```scala
// BAD — triggers the cliff
@Benchmark
def my_op(bh: Blackhole): Unit =
bh.consume(arr > 0.0f)

// GOOD — C2 scalar-replaces Vector objects correctly
@Benchmark
def my_op(bh: Blackhole): Array[Boolean] =
val result = arr > 0.0f
bh.consume(result)
result
```

**When diagnosing a performance regression**:
1. First check `gc.alloc.rate.norm` with `-prof gc`. Expected alloc is `sizeof(output array) + small constant`.
2. If alloc is 30–70× too high, suspect this JMH anti-pattern before touching library code.
3. Confirm by adding a `_returning` variant that returns the result — if it measures correctly, the benchmark method signature is the culprit.

**Not affected**: Benchmarks that mutate in-place (`Unit` is fine), or those where `bh.consume` is called on a pre-existing field/variable that was allocated outside the hot loop.
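For step 1 of the checklist above, the expected figure is easy to compute by hand. The following sketch shows the arithmetic (`ExpectedAlloc` is a hypothetical helper, and the ~16-byte array header is an assumption for a typical 64-bit HotSpot with compressed oops):

```scala
// Sketch: estimate the expected gc.alloc.rate.norm for a benchmark that
// allocates exactly one output array per op. The 16-byte array header is
// an assumption (64-bit HotSpot, compressed oops); real figures also carry
// a small constant of JMH bookkeeping on top.
object ExpectedAlloc:
  def floatArrayBytes(n: Int): Long = 16L + 4L * n     // header + 4 bytes per Float
  def booleanArrayBytes(n: Int): Long = 16L + 1L * n   // header + 1 byte per Boolean

  @main def show(): Unit =
    println(floatArrayBytes(512 * 128))   // a (512, 128) Float output: ~256 KB/op
    println(booleanArrayBytes(512 * 128)) // a (512, 128) Boolean mask: ~64 KB/op
```

If `-prof gc` reports 30–70× more than this per-op estimate, suspect the `Unit`-return anti-pattern before touching library code.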

## Vecxt Re

Contains domain-specific code for reinsurance calculations, structures, and various reinsurance contract types. It often relies on Vecxt. Treat the principles as the same: correctness above all else, then performance. It also aims to expose a consistent cross-platform API.
289 changes: 289 additions & 0 deletions benchmark/src/mnistBenchmark.scala
@@ -0,0 +1,289 @@
package vecxt.benchmark

import org.openjdk.jmh.annotations.*
import org.openjdk.jmh.infra.Blackhole
import vecxt.all.*
import vecxt.BoundsCheck
import BoundsCheck.DoBoundsCheck.no
import scala.compiletime.uninitialized
import scala.util.chaining.* // for .tap used in colSums pipelines below

import java.util.concurrent.TimeUnit

/** Breaks down the MNIST forward + backward pass into individual operations to identify where time is spent. Uses
* realistic dimensions from the MNIST training loop: batchSize x 784 input, 784x128 hidden, 128x10 output.
*/

// ./mill benchmark.runJmh "vecxt.benchmark.MnistBenchmark" -jvmArgs "--add-modules=jdk.incubator.vector"
@State(Scope.Thread)
class MnistBenchmark extends BLASBenchmark:

// MNIST network dimensions
val imageSize = 784 // 28*28
val hiddenSize = 128
val outputSize = 10

@Param(Array("128", "512"))
var batchSize: String = uninitialized

var bs: Int = uninitialized

// Forward pass inputs
var xBatch: Matrix[Float] = uninitialized // (bs, 784)
var w1: Matrix[Float] = uninitialized // (784, 128)
var b1: Array[Float] = uninitialized // (128)
var w2: Matrix[Float] = uninitialized // (128, 10)
var b2: Array[Float] = uninitialized // (10)

// Forward pass intermediates (for backward benchmarks)
var z1: Matrix[Float] = uninitialized // (bs, 128)
var a1: Matrix[Float] = uninitialized // (bs, 128)
var z2: Matrix[Float] = uninitialized // (bs, 10)
var a2: Matrix[Float] = uninitialized // (bs, 10)
var yBatch: Matrix[Float] = uninitialized // (bs, 10) one-hot

// Backward intermediates
var dz2: Matrix[Float] = uninitialized // (bs, 10)
var dz1: Matrix[Float] = uninitialized // (bs, 128)
var dz1Check: Matrix[Boolean] = uninitialized
var a1T: Matrix[Float] = uninitialized // (128, bs)
var xT: Matrix[Float] = uninitialized // (784, bs)
var w2T: Matrix[Float] = uninitialized // (10, 128)

// Weight update intermediates
var dw1: Matrix[Float] = uninitialized
var dw2: Matrix[Float] = uninitialized
var db1: Array[Float] = uninitialized
var db2: Array[Float] = uninitialized

@Setup(Level.Trial)
def setup: Unit =
bs = batchSize.toInt

xBatch = Matrix(randomFloatArray(bs * imageSize), (bs, imageSize))
w1 = Matrix(randomFloatArray(imageSize * hiddenSize).map(_ * 0.2f), (imageSize, hiddenSize))
b1 = randomFloatArray(hiddenSize)
w2 = Matrix(randomFloatArray(hiddenSize * outputSize).map(_ * 0.2f), (hiddenSize, outputSize))
b2 = randomFloatArray(outputSize)

// One-hot labels
val yRaw = Array.fill(bs * outputSize)(0.0f)
val rng = new java.util.Random(42)
var i = 0
while i < bs do
yRaw(i + bs * rng.nextInt(outputSize)) = 1.0f
i += 1
end while
yBatch = Matrix(yRaw, (bs, outputSize))

// Pre-compute forward pass intermediates for backward benchmarks
z1 = xBatch @@ w1
z1.mapRowsInPlace { r => r += b1; r }
a1 = Matrix(z1.raw.clampMin(0.0f), z1.shape)
z2 = a1 @@ w2
z2.mapRowsInPlace { r => r += b2; r }
a2 = softmaxRowsBench(z2.deepCopy)

dz2 = a2 - yBatch
dz1Check = z1 > 0
a1T = a1.transpose
xT = xBatch.transpose
w2T = w2.transpose

dz1 = (dz2 @@ w2T)
dz1 *:*= dz1Check

val m_inv = 1.0f / bs
dw1 = m_inv * (xT @@ dz1)
dw2 = m_inv * (a1T @@ dz2)
db1 = dz1.colSums.tap(_ *= m_inv)
db2 = dz2.colSums
()
end setup

// ============================================================
// FORWARD PASS — individual operations
// ============================================================

@Benchmark
def fwd_01_matmul_x_w1(bh: Blackhole): Unit =
// (bs, 784) @@ (784, 128) — the big matmul
bh.consume(xBatch @@ w1)

@Benchmark
def fwd_02_bias_add_b1(bh: Blackhole): Unit =
val z = z1.deepCopy
z.mapRowsInPlace { r => r += b1; r }
bh.consume(z)
end fwd_02_bias_add_b1

@Benchmark
def fwd_03_relu(bh: Blackhole): Unit =
bh.consume(Matrix(z1.raw.clampMin(0.0f), z1.shape))

@Benchmark
def fwd_04_matmul_a1_w2(bh: Blackhole): Unit =
// (bs, 128) @@ (128, 10)
bh.consume(a1 @@ w2)

@Benchmark
def fwd_05_bias_add_b2(bh: Blackhole): Unit =
val z = z2.deepCopy
z.mapRowsInPlace { r => r += b2; r }
bh.consume(z)
end fwd_05_bias_add_b2

@Benchmark
def fwd_06_softmax(bh: Blackhole): Unit =
bh.consume(softmaxRowsBench(z2.deepCopy))

@Benchmark
def fwd_full_forward(bh: Blackhole): Unit =
val z1_ = xBatch @@ w1
z1_.mapRowsInPlace { r => r += b1; r }
val a1_ = Matrix(z1_.raw.clampMin(0.0f), z1_.shape)
val z2_ = a1_ @@ w2
z2_.mapRowsInPlace { r => r += b2; r }
val a2_ = softmaxRowsBench(z2_)
bh.consume(a2_)
end fwd_full_forward

// ============================================================
// BACKWARD PASS — individual operations
// ============================================================

@Benchmark
def bwd_01_dz2_sub(bh: Blackhole): Unit =
bh.consume(a2 - yBatch)

@Benchmark
def bwd_02_matmul_a1T_dz2(bh: Blackhole): Unit =
// (128, bs) @@ (bs, 10) — gradient for w2
bh.consume(a1T @@ dz2)

@Benchmark
def bwd_03_transpose_a1(bh: Blackhole): Unit =
bh.consume(a1.transpose)

@Benchmark
def bwd_04_db2_col_sum(bh: Blackhole): Array[Float] =
val result = dz2.colSums
bh.consume(result)
result
end bwd_04_db2_col_sum

@Benchmark
def bwd_05_relu_mask(bh: Blackhole): Matrix[Boolean] =
val result = z1 > 0
bh.consume(result)
result
end bwd_05_relu_mask

@Benchmark
def bwd_06_matmul_dz2_w2T(bh: Blackhole): Unit =
// (bs, 10) @@ (10, 128) — propagate error back
bh.consume(dz2 @@ w2T)

@Benchmark
def bwd_07_mask_multiply(bh: Blackhole): Unit =
val dz = (dz2 @@ w2T)
dz *:*= dz1Check
bh.consume(dz)
end bwd_07_mask_multiply

@Benchmark
def bwd_07b_zeroWhere(bh: Blackhole): Unit =
// Fused alternative: single SIMD pass, no boolean allocation
val dz = (dz2 @@ w2T)
dz.raw.`zeroWhere!`(z1.raw, 0.0f, ComparisonOp.LE)
bh.consume(dz)
end bwd_07b_zeroWhere

@Benchmark
def bwd_08_matmul_xT_dz1(bh: Blackhole): Unit =
// (784, bs) @@ (bs, 128) — gradient for w1, the big one
bh.consume(xT @@ dz1)

@Benchmark
def bwd_09_db1_col_sum(bh: Blackhole): Array[Float] =
val result = dz1.colSums
result *= (1.0f / bs)
bh.consume(result)
result
end bwd_09_db1_col_sum

@Benchmark
def bwd_full_backward(bh: Blackhole): Unit =
val m_inv = 1.0f / bs
val dz2_ = a2 - yBatch
val dw2_ = m_inv * (a1T @@ dz2_)
val db2_ = dz2_.colSums
val dz1_ = (dz2_ @@ w2T)
dz1_.raw.`zeroWhere!`(z1.raw, 0.0f, ComparisonOp.LE)
val dw1_ = m_inv * (xT @@ dz1_)
val db1_ = dz1_.colSums.tap(_ *= m_inv)
bh.consume(dw1_)
end bwd_full_backward

// ============================================================
// WEIGHT UPDATE
// ============================================================

// @Benchmark
// def upd_01_w1_update(bh: Blackhole): Unit =
// import BoundsCheck.DoBoundsCheck.yes
// val w = w1.deepCopy
// w -= (dw1 * 0.05f)
// bh.consume(w)

// @Benchmark
// def upd_02_w2_update(bh: Blackhole): Unit =
// import BoundsCheck.DoBoundsCheck.yes
// val w = w2.deepCopy
// w -= (dw2 * 0.05f)
// bh.consume(w)

// @Benchmark
// def upd_03_b1_update(bh: Blackhole): Unit =
// val b = b1.clone()
// b -= (db1 * 0.05f)
// bh.consume(b)

// ============================================================
// FULL STEP (forward + backward + update) for reference
// ============================================================

@Benchmark
def full_training_step(bh: Blackhole): Unit =
val alpha = 0.05f
val m_inv = 1.0f / bs
// Forward
val z1_ = xBatch @@ w1
z1_.mapRowsInPlace { r => r += b1; r }
val a1_ = Matrix(z1_.raw.clampMin(0.0f), z1_.shape)
val z2_ = a1_ @@ w2
z2_.mapRowsInPlace { r => r += b2; r }
val a2_ = softmaxRowsBench(z2_)
// Backward
val dz2_ = a2_ - yBatch
val dw2_ = m_inv * (a1_.transpose @@ dz2_)
val db2_ = dz2_.colSums
val dz1_ = dz2_ @@ w2.transpose
dz1_.raw.`zeroWhere!`(z1_.raw, 0.0f, ComparisonOp.LE)
val dw1_ = m_inv * (xBatch.transpose @@ dz1_)
val db1_ = dz1_.colSums.tap(_ *= m_inv)
// Update (consume results)
bh.consume(dw1_)
bh.consume(dw2_)
bh.consume(db1_)
bh.consume(db2_)
end full_training_step

private def softmaxRowsBench(z: Matrix[Float]): Matrix[Float] =
z.mapRows { row =>
row -= row.max
row.`exp!`
row /= row.sum
row
}

end MnistBenchmark
14 changes: 5 additions & 9 deletions experiments/src/mnist.scala
@@ -211,14 +211,14 @@ def back_prop(
   val dz2 = a2 - Y
   val dw2 = m_inv * (a1.transpose @@ dz2)

-  val db2 = dz2.mapColsToScalar(_.sum).raw
-  val dz1Check = (z1 > 0)
+  val db2 = dz2.sum(0).raw
+  // val dz1Check = (z1 > 0)
   val dz1 = (dz2 @@ w2.transpose)
-  dz1 *:*= dz1Check
+  dz1.raw.`zeroWhere!`(z1.raw, 0.0, ComparisonOp.LE)

   val dw1 = m_inv * (X.transpose @@ dz1)

-  val db1 = dz1.mapColsToScalar(r => r.sumSIMD * m_inv).raw
+  val db1 = dz1.sum(0).raw * m_inv
   // println("back propagation (Float) done ----")
   (dw1 = dw1, db1 = db1, dw2 = dw2, db2 = db2)
 end back_prop
@@ -248,12 +248,8 @@ def back_prop(
   // println(s"dz2 shape: ${dz2.shape}, dz2 rows: ${dz2.rows}, dz2 cols: ${dz2.cols}")
   // println(s"dw2 shape: ${dw2.shape}, dw2 rows: ${dw2.rows}, dw2 cols: ${dw2.cols}")
   val db2 = dz2.mapColsToScalar(_.sum).raw
-  val dz1Check = (z1 > 0)
-  // println(s"dz2 shape: ${dz2.shape}, dz2 rows: ${dz2.rows}, dz2 cols: ${dz2.cols}\n")
-  // println(s"dz1Check: ${dz1Check.shape}, dz1Check rows: ${dz1Check.rows}, dz1Check cols: ${dz1Check.cols}\n")
-  // println(s"dz1Check: ${dz1Check(0 to 10, ::).printMat}\n")
   val dz1 = (dz2 @@ w2.transpose)
-  dz1 *:*= dz1Check // (10, 784)
+  dz1.raw.`zeroWhere!`(z1.raw, 0.0, ComparisonOp.LE) // (10, 784)
   // print(s"dz1 shape: ${dz1.shape}, dz1 rows: ${dz1.rows}, dz1 cols: ${dz1.cols}\n")
   val dw1 = m_inv * (X.transpose @@ dz1)
   val db1 = dz1.mapColsToScalar(r => r.sumSIMD * m_inv).raw