Fix complex support for BatchedDot in C and JAX backends#1909
Open
ayulockedin wants to merge 1 commit intopymc-devs:mainfrom
Open
Fix complex support for BatchedDot in C and JAX backends#1909ayulockedin wants to merge 1 commit intopymc-devs:mainfrom
ayulockedin wants to merge 1 commit intopymc-devs:mainfrom
Conversation
Resolves pymc-devs#1849. Refactors the BatchedDot C-backend template to accept raw float arrays for alpha and beta, bypassing strict struct type mismatches in cgemm/zgemm. Enables native support for complex64 and complex128. Adds comprehensive tests for complex batched matrix multiplications.
67aabb0 to
e1112b5
Compare
Contributor
Author
|
@jessegrabowski could u take a look at this PR when u have a moment thx :) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR adds native support for
complex64andcomplex128data types to theBatchedDotOp across both the C and JAX backends, resolving theNotImplementedErrorraised during graph compilation mentioned by @jessegrabowski in issue #1849Key technical changes:
make_nodetype-checking inpytensor/tensor/blas.pyto explicitly allow complex types to be built into the graph.alphaandbetaas rawfloat[2]anddouble[2]arrays. This matches the exact memory layout expected by the underlying Fortran ABI forcgemm_andzgemm_.PyArray_DATAtochar*before applying NumPy byte strides. This guarantees safe memory traversal regardless of the underlying primitive byte size.matmulwithout requiring an explicit custom dispatch.test_batched_dot_complex) totests/tensor/test_blas.py. Implemented the modern@matrix multiplication operator in the tests to ensure zeroFutureWarningdeprecations are triggered in the CI pipeline.Related Issue
Checklist
Type of change