# FlashAttention Assignment for the LLM Course
Create and activate the conda environment:
```bash
conda env create -f environment.yml
conda activate flashattention
```
Repository structure:

```text
├── online_softmax/
│   ├── online_softmax.py        # Online softmax implementation (Triton + PyTorch)
│   └── fused_softmax.py         # Fused softmax with matrix multiplication
├── benchmarking/
│   ├── bench_softmax.py         # Benchmark for online softmax
│   └── bench_fused.py           # Benchmark for fused softmax
├── tests/
│   ├── test_online_softmax.py
│   └── test_fused_softmax.py
└── environment.yml
```
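Here, online_softmax.py implements the core technique: online softmax computes a numerically stable softmax in a single streaming pass by tracking a running maximum and a running normalizer that is rescaled whenever the maximum grows. As a reference for what the Triton kernel computes block by block, below is a minimal PyTorch sketch of that recurrence (function and variable names are illustrative, not the repo's actual API):

```python
import torch

def online_softmax_reference(x: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax over the last dim via one streaming pass per row.

    Illustrative sketch only: the assignment's Triton kernel applies the same
    rescaling recurrence block by block rather than element by element.
    """
    m = torch.full(x.shape[:-1], float("-inf"))  # running max per row
    d = torch.zeros(x.shape[:-1])                # running normalizer per row
    for i in range(x.shape[-1]):
        m_new = torch.maximum(m, x[..., i])
        # Rescale the old normalizer to the new max, then add the new term.
        d = d * torch.exp(m - m_new) + torch.exp(x[..., i] - m_new)
        m = m_new
    return torch.exp(x - m.unsqueeze(-1)) / d.unsqueeze(-1)

# Sanity check against the library implementation:
x = torch.randn(4, 128)
assert torch.allclose(online_softmax_reference(x), torch.softmax(x, dim=-1), atol=1e-5)
```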
Run the tests:

```bash
pytest tests/ -v
```

Run the benchmarks:

```bash
python benchmarking/bench_softmax.py
python benchmarking/bench_fused.py
```

Results are saved to the outputs/ directory.
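The benchmark scripts' internals aren't reproduced here; one plausible shape for such a script, assuming it times the kernel against torch.softmax with triton.testing.do_bench (the function names and problem sizes below are illustrative):

```python
import torch
import triton

# Hypothetical benchmarking sketch; the repo's bench_softmax.py may differ.
def bench_softmax(n_rows: int = 4096, n_cols: int = 1024) -> None:
    x = torch.randn(n_rows, n_cols, device="cuda")
    # do_bench handles warmup and repetition, returning a time in ms.
    ms_ref = triton.testing.do_bench(lambda: torch.softmax(x, dim=-1))
    print(f"torch.softmax ({n_rows}x{n_cols}): {ms_ref:.3f} ms")
    # Swap in the assignment's kernel here for comparison,
    # e.g. triton.testing.do_bench(lambda: online_softmax(x)).

if __name__ == "__main__":
    bench_softmax()
```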
Notes:

- The fused softmax kernel uses `tl.dot()`, which requires all dimensions to be >= 16 (a tensor core constraint).
- Block sizes must be >= 16 for the fused softmax implementation.
- Numerical tolerance for the fused softmax tests is 1e-3 (vs. 1e-5 for the simple softmax) due to error accumulation from the online algorithm plus TF32 tensor core operations; see the test sketch below.
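To make the last two points concrete, here is a sketch of a tolerance-aware check comparing the fused kernel against a PyTorch reference. The fused_softmax entry point and its module path are assumptions based on the file listing above; the actual tests may be organized differently:

```python
import torch
from online_softmax.fused_softmax import fused_softmax  # assumed entry point

def test_fused_softmax_matches_reference():
    # Both dimensions are >= 16 so tl.dot() can map onto tensor cores.
    x = torch.randn(64, 64, device="cuda")
    w = torch.randn(64, 64, device="cuda")
    ref = torch.softmax(x @ w, dim=-1)
    out = fused_softmax(x, w)
    # Looser 1e-3 tolerance: online rescaling and TF32 matmuls both
    # accumulate more rounding error than a plain fp32 softmax.
    assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)
```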