[Question] Roadmap for supporting additional K/V head dimensions and backward kernels #4

@dmm19941210

Hi, thanks for open-sourcing FlashKDA!

I am currently evaluating it as a backend for fla.ops.kda.chunk_kda.
From the README and API, my understanding is that the current implementation mainly targets inference-mode forward execution and appears to assume a K/V head dimension of 128.
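
To make the question concrete, here is a minimal sketch of the kind of dispatch guard this evaluation implies, assuming the two constraints above hold (forward-only, head dimension 128). The names `KDACall`, `SUPPORTED_HEAD_DIMS`, and `select_backend` are hypothetical and not part of FlashKDA or fla; the sketch only decides which backend to use and calls neither library.

```python
from dataclasses import dataclass

# Assumption based on my reading of the README, not a confirmed FlashKDA
# constraint: only a K/V head dimension of 128 is supported.
SUPPORTED_HEAD_DIMS = frozenset({128})


@dataclass(frozen=True)
class KDACall:
    """Hypothetical description of one chunk_kda invocation."""
    head_dim: int        # per-head K/V dimension
    requires_grad: bool  # True when a backward pass will be needed


def select_backend(call: KDACall) -> str:
    """Return "flashkda" only when the call fits the assumed constraints."""
    if call.requires_grad:
        # Assumed forward-only: training paths stay on the existing backend.
        return "fallback"
    if call.head_dim not in SUPPORTED_HEAD_DIMS:
        # Unsupported head dimension (for example 64): use the existing backend.
        return "fallback"
    return "flashkda"


if __name__ == "__main__":
    for call in (KDACall(128, False), KDACall(64, False), KDACall(128, True)):
        print(call, "->", select_backend(call))
```

If more head dimensions or a backward pass were added later, only `SUPPORTED_HEAD_DIMS` and the `requires_grad` check in this guard would need to change.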

I have a couple of questions regarding the roadmap:

  1. Are there any plans to support K/V head dimensions other than 128?
    For example, we have seen some KDA configurations where the per-head K/V dimension is 64 rather than 128. Our question is not limited to D=64; more generally, is FlashKDA expected to support additional common head dimensions in the future?
    If support for other dimensions is planned, would it likely be implemented as separate CUDA/CUTLASS specializations, or as a more general implementation?
  2. Are there any plans to add backward/autograd support, or is the project intended to remain an inference-only forward backend?

Thanks again for your work, and have a great weekend ahead!
