Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions compute/arena_size_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package compute

import (
"testing"

"github.com/zerfoo/ztensor/log"
)

func TestArenaSizeBytes_Default(t *testing.T) {
t.Setenv("ZERFOO_ARENA_SIZE_GB", "")
got := arenaSizeBytes(log.Nop())
want := int64(defaultArenaSizeGB) * 1024 * 1024 * 1024
if got != want {
t.Fatalf("default arena size: got %d, want %d", got, want)
}
}

func TestArenaSizeBytes_EnvOverride(t *testing.T) {
tests := []struct {
name string
env string
wantG int64
}{
{"min", "1", 1},
{"training-typical", "32", 32},
{"max", "128", 128},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("ZERFOO_ARENA_SIZE_GB", tt.env)
got := arenaSizeBytes(log.Nop())
want := tt.wantG * 1024 * 1024 * 1024
if got != want {
t.Fatalf("env=%q: got %d, want %d", tt.env, got, want)
}
})
}
}

func TestArenaSizeBytes_InvalidFallsBackToDefault(t *testing.T) {
tests := []struct {
name string
env string
}{
{"non-integer", "lots"},
{"below-min", "0"},
{"above-max", "256"},
{"negative", "-5"},
}
wantDefault := int64(defaultArenaSizeGB) * 1024 * 1024 * 1024
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Setenv("ZERFOO_ARENA_SIZE_GB", tt.env)
got := arenaSizeBytes(log.Nop())
if got != wantDefault {
t.Fatalf("env=%q: got %d, want default %d", tt.env, got, wantDefault)
}
})
}
}
53 changes: 48 additions & 5 deletions compute/gpu_engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"os"
"strconv"
"sync/atomic"
"unsafe"

Expand All @@ -16,6 +17,45 @@ import (
"github.com/zerfoo/ztensor/tensor"
)

// defaultArenaSizeGB is the per-GPUEngine arena capacity when the
// ZERFOO_ARENA_SIZE_GB env var is unset. Sized for single-pass inference
// on typical 1-7B LLMs; larger training workloads (e.g. multi-scale
// walk-forward with hundreds of batches) should raise this via env var
// to keep per-step intermediates inside the arena and avoid spill to
// the unbounded MemPool fallback.
const defaultArenaSizeGB = 2

// minArenaSizeGB and maxArenaSizeGB bound user-supplied values so a typo
// can't request a TB-sized arena or a zero-size one.
const (
minArenaSizeGB = 1
maxArenaSizeGB = 128
)

// arenaSizeBytes resolves the arena capacity in bytes. ZERFOO_ARENA_SIZE_GB,
// if set to an integer in [minArenaSizeGB, maxArenaSizeGB], overrides the
// default. Invalid / out-of-range values are logged and ignored.
func arenaSizeBytes(l log.Logger) int64 {
gb := int64(defaultArenaSizeGB)
if raw := os.Getenv("ZERFOO_ARENA_SIZE_GB"); raw != "" {
parsed, err := strconv.ParseInt(raw, 10, 64)
switch {
case err != nil:
l.Warn("ZERFOO_ARENA_SIZE_GB is not an integer; using default",
"value", raw, "default", fmt.Sprintf("%d", defaultArenaSizeGB))
case parsed < minArenaSizeGB || parsed > maxArenaSizeGB:
l.Warn("ZERFOO_ARENA_SIZE_GB out of range; using default",
"value", fmt.Sprintf("%d", parsed),
"min", fmt.Sprintf("%d", minArenaSizeGB),
"max", fmt.Sprintf("%d", maxArenaSizeGB),
"default", fmt.Sprintf("%d", defaultArenaSizeGB))
default:
gb = parsed
}
}
return gb * 1024 * 1024 * 1024
}

// DType selects the compute precision for GPU operations.
type DType int

Expand Down Expand Up @@ -169,11 +209,14 @@ func NewGPUEngine[T tensor.Numeric](ops numeric.Arithmetic[T], deviceID ...int)
fallbackPool := cuda.NewMemPool()
cuda.SetDefaultMemPool(fallbackPool)

// Arena pool: 2GB pre-allocated region for per-inference intermediates.
// On DGX Spark with 128GB unified memory, this is a small fraction.
// Falls back to MemPool if arena is exhausted.
const arenaSize = 2 * 1024 * 1024 * 1024 // 2 GB
arenaPool, err := gpuapi.NewCUDAArenaPool(dev, arenaSize, fallbackPool)
// Arena pool: pre-allocated region for per-inference / per-step
// intermediates. Defaults to 2GB (sized for 1-7B LLM inference).
// Override via ZERFOO_ARENA_SIZE_GB for larger training workloads
// whose per-step working set would otherwise spill to the unbounded
// MemPool fallback and leak through StepScope.Close(). On DGX Spark
// with 128GB unified memory, sizes up to 128GB are valid.
arenaSize := arenaSizeBytes(l)
arenaPool, err := gpuapi.NewCUDAArenaPool(dev, int(arenaSize), fallbackPool)
if err == nil {
cuda.SetDefaultArenaPool(arenaPool.Inner())
}
Expand Down
Loading