Hi, thanks for open-sourcing FlashKDA!
I am currently evaluating it as a backend for `fla.ops.kda.chunk_kda`.
From the README and API, my understanding is that the current implementation mainly targets inference-mode forward execution and seems to assume a K/V head dimension of 128.
I have a couple of questions regarding the roadmap:
- Are there any plans to support K/V head dimensions other than 128?
For example, we have seen some KDA configurations where the per-head K/V dimension is 64 rather than 128. The question is not limited to D=64; more generally, is FlashKDA expected to support additional common head dimensions in the future?
If support for other dimensions is planned, would it likely be implemented as separate CUDA/CUTLASS specializations, or as a more general implementation?
- Are there any plans to add backward/autograd support, or is the project intended to remain an inference-only forward backend?
Thanks again for your work, and have a great weekend ahead!