@jonahsamost

The speedup comes mainly from warp shuffles and synchronization over the H dimension. Some reused variables are also kept in registers. The original backward kernel had some bugs, which are fixed in the optimized version.
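For readers unfamiliar with the technique, here is a minimal sketch of a warp-shuffle reduction over an H dimension. It is not the kernel from this PR: `warp_reduce_sum`, `row_sum_kernel`, and `check_row_sums` are hypothetical names, and the one-warp-per-row layout is an assumption chosen only to keep the example short.

```cuda
#include <cmath>
#include <cstddef>
#include <vector>
#include <cuda_runtime.h>

// Sum a value across the 32 lanes of a warp with register-to-register
// shuffles, so no shared memory or __syncthreads() is needed.
__inline__ __device__ float warp_reduce_sum(float v) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(0xffffffffu, v, offset);
    }
    return v;  // lane 0 holds the sum of all 32 lanes
}

// Hypothetical kernel (assumption: one warp per row, blockDim.x a multiple of 32).
// Computes out[row] = sum over h of x[row * H + h].
__global__ void row_sum_kernel(const float* x, float* out, int rows, int H) {
    const int warps_per_block = blockDim.x / 32;
    const int row = blockIdx.x * warps_per_block + threadIdx.x / 32;
    const int lane = threadIdx.x % 32;
    if (row >= rows) return;  // row is uniform per warp, so the whole warp exits

    const float* xr = x + (size_t)row * H;
    float partial = 0.0f;  // per-lane accumulator, lives in a register
    for (int h = lane; h < H; h += 32) {
        partial += xr[h];  // lanes read consecutive addresses (coalesced)
    }
    const float total = warp_reduce_sum(partial);
    if (lane == 0) out[row] = total;
}

// Host helper: runs the kernel and compares against a CPU reference.
// Returns the number of mismatching rows, or -1 on a CUDA error.
int check_row_sums(int rows, int H) {
    std::vector<float> h_x((size_t)rows * H), h_out(rows);
    for (size_t i = 0; i < h_x.size(); ++i) h_x[i] = 0.5f * (float)(i % 7);

    float* d_x = nullptr;
    float* d_out = nullptr;
    if (cudaMalloc(&d_x, h_x.size() * sizeof(float)) != cudaSuccess) return -1;
    if (cudaMalloc(&d_out, rows * sizeof(float)) != cudaSuccess) {
        cudaFree(d_x);
        return -1;
    }
    cudaMemcpy(d_x, h_x.data(), h_x.size() * sizeof(float), cudaMemcpyHostToDevice);

    const int threads = 128;  // 4 warps per block
    const int warps_per_block = threads / 32;
    const int blocks = (rows + warps_per_block - 1) / warps_per_block;
    row_sum_kernel<<<blocks, threads>>>(d_x, d_out, rows, H);

    // cudaMemcpy synchronizes with the kernel and surfaces any launch error.
    cudaError_t err = cudaMemcpy(h_out.data(), d_out, rows * sizeof(float),
                                 cudaMemcpyDeviceToHost);
    cudaFree(d_x);
    cudaFree(d_out);
    if (err != cudaSuccess) return -1;

    int mismatches = 0;
    for (int r = 0; r < rows; ++r) {
        double ref = 0.0;
        for (int h = 0; h < H; ++h) ref += h_x[(size_t)r * H + h];
        if (std::fabs(ref - h_out[r]) > 1e-3 * (1.0 + std::fabs(ref))) ++mismatches;
    }
    return mismatches;
}
```

Calling `check_row_sums(512, 128)` from a `main` should return 0 when the GPU result matches the CPU reference. The shuffle loop replaces a shared-memory reduction with block-wide barriers, which is the general kind of saving the description above refers to.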

Speedup ranges from about 2.3x to 5.1x on forward and from about 3.7x to 13x on backward, depending on the configuration.

All benchmarks were run on an NVIDIA L4 GPU.

Forward Benchmarks

| Config | Original | Optimized | Speedup |
| --- | --- | --- | --- |
| B=512, T=64, H=128 | 0.306ms | 0.060ms | 5.10x |
| B=512, T=64, H=256 | 0.741ms | 0.290ms | 2.55x |
| B=512, T=64, H=384 | 1.251ms | 0.445ms | 2.81x |
| B=512, T=96, H=256 | 1.045ms | 0.446ms | 2.34x |
| B=768, T=64, H=256 | 1.043ms | 0.447ms | 2.34x |
| B=768, T=64, H=512 | 2.539ms | 0.866ms | 2.93x |
| B=1024, T=64, H=256 | 1.390ms | 0.590ms | 2.36x |
| B=1024, T=64, H=512 | 3.360ms | 1.151ms | 2.92x |
| B=1024, T=96, H=384 | 3.645ms | 1.295ms | 2.81x |
| B=1024, T=128, H=256 | 2.929ms | 1.161ms | 2.52x |
| B=1536, T=64, H=512 | 5.205ms | 1.716ms | 3.03x |
| B=2048, T=64, H=256 | 2.921ms | 1.160ms | 2.52x |
| B=2048, T=64, H=512 | 6.859ms | 2.282ms | 3.01x |
| B=2048, T=96, H=512 | 10.102ms | 3.414ms | 2.96x |
| B=1024, T=91, H=384 | 3.440ms | 1.234ms | 2.79x |
| B=1536, T=77, H=512 | 6.096ms | 2.055ms | 2.97x |

Backward Benchmarks

| Config | Original | Optimized | Speedup |
| --- | --- | --- | --- |
| B=512, T=64, H=128 | 1.021ms | 0.279ms | 3.67x |
| B=512, T=64, H=256 | 3.829ms | 0.598ms | 6.40x |
| B=512, T=64, H=384 | 8.470ms | 0.890ms | 9.51x |
| B=512, T=96, H=256 | 5.797ms | 0.879ms | 6.59x |
| B=768, T=64, H=256 | 5.753ms | 0.885ms | 6.50x |
| B=768, T=64, H=512 | 22.268ms | 1.749ms | 12.74x |
| B=1024, T=64, H=256 | 7.742ms | 1.175ms | 6.59x |
| B=1024, T=64, H=512 | 29.729ms | 2.361ms | 12.59x |
| B=1024, T=96, H=384 | 25.680ms | 2.629ms | 9.77x |
| B=1024, T=128, H=256 | 15.893ms | 2.326ms | 6.83x |
| B=1536, T=64, H=512 | 44.865ms | 3.497ms | 12.83x |
| B=2048, T=64, H=256 | 15.880ms | 2.310ms | 6.88x |
| B=2048, T=64, H=512 | 60.022ms | 4.701ms | 12.77x |
| B=2048, T=96, H=512 | 90.278ms | 7.036ms | 12.83x |
| B=512, T=128, H=512 | 30.235ms | 2.347ms | 12.88x |
| B=1024, T=91, H=384 | 24.833ms | 2.496ms | 9.95x |
| B=1536, T=77, H=512 | 54.873ms | 4.222ms | 13.00x |
