small fixes #3

Open
ngc92 wants to merge 1 commit into main from fixes
Conversation

ngc92 (Contributor) commented Feb 20, 2026

inline to prevent ODR violations, and better error messages for dtype mismatch

Copilot AI left a comment

Pull request overview

This PR makes small robustness improvements in the quartet2 kernels by avoiding C++ header ODR issues and replacing a brittle dtype assertion with clearer runtime errors.

Changes:

  • Replace assert on input.dtype with explicit TypeErrors for weight/input dtype mismatches (bfloat16 requirement).
  • Mark ptx_type_name variable template specializations as inline constexpr in a CUDA header to prevent ODR violations.
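The first change can be illustrated with a minimal sketch. It is not the code from `linear.py`: the helper name `check_dtypes` and the string stand-in `BFLOAT16` are hypothetical, chosen so the snippet runs without PyTorch. In the real module the comparison would be against `torch.bfloat16`. The motivation for preferring an explicit `raise` is that `assert` statements are stripped when Python runs with `-O`, and a bare assertion failure gives the caller no hint about how to fix the problem.

```python
from types import SimpleNamespace

# Hypothetical stand-in for torch.bfloat16 so the sketch runs without PyTorch.
BFLOAT16 = "bfloat16"


def check_dtypes(weight, x, required=BFLOAT16):
    """Raise TypeError if `weight` or `x` does not have the required dtype.

    Unlike `assert x.dtype == required`, this check still runs under
    `python -O` and tells the caller which operand is wrong.
    """
    if weight.dtype != required:
        raise TypeError(
            f"Weight must be {required}, got {weight.dtype}. "
            "Either create the weight with that dtype or enable autocast."
        )
    elif x.dtype != required:
        raise TypeError(
            f"Input must be {required}, got {x.dtype}. "
            "Either cast the input or enable autocast."
        )


# Stand-in objects that only carry a `dtype` attribute, like a tensor would.
good = SimpleNamespace(dtype="bfloat16")
bad = SimpleNamespace(dtype="float32")

check_dtypes(good, good)  # passes silently

try:
    check_dtypes(good, bad)
except TypeError as err:
    print(err)
```

Checking the weight first mirrors the order of the branches quoted in the review comment, so a misconfigured module is reported before a mis-typed input.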

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| kernels/python/quartet2/linear.py | Improves dtype mismatch handling by raising clearer exceptions instead of asserting. |
| kernels/csrc/utils.cuh | Adds inline to constexpr variable templates to avoid multiple-definition/ODR issues from header inclusion. |

Comment on lines +168 to +170
    raise TypeError("Weight must be bfloat16. Either set `dtype=torch.bfloat16` or enable autocast`")
elif input.dtype != torch.bfloat16:
    raise TypeError("Input must be bfloat16. Either cast input to bfloat16 or enable autocast`")