Skip to content

Fix complex support for BatchedDot in C and JAX backends#1909

Open
ayulockedin wants to merge 1 commit intopymc-devs:mainfrom
ayulockedin:feature/complex-batched-dot
Open

Fix complex support for BatchedDot in C and JAX backends#1909
ayulockedin wants to merge 1 commit intopymc-devs:mainfrom
ayulockedin:feature/complex-batched-dot

Conversation

@ayulockedin
Copy link
Contributor

Description

This PR adds native support for complex64 and complex128 data types to the BatchedDot Op across both the C and JAX backends, resolving the NotImplementedError raised during graph compilation mentioned by @jessegrabowski in issue #1849

Key technical changes:

  • Python Graph Layer: Updated make_node type-checking in pytensor/tensor/blas.py to explicitly allow complex types to be built into the graph.
  • C-Backend ABI Compatibility: Bypassed strict C++ complex struct type mismatches by defining alpha and beta as raw float[2] and double[2] arrays. This matches the exact memory layout expected by the underlying Fortran ABI for cgemm_ and zgemm_.
  • Safe Pointer Arithmetic: Prevented stride-related memory corruption in the C-loop by casting PyArray_DATA to char* before applying NumPy byte strides. This guarantees safe memory traversal regardless of the underlying primitive byte size.
  • JAX Backend: Verified that once the PyTensor Python layer permits complex inputs, execution delegates seamlessly to JAX's native complex matmul without requiring an explicit custom dispatch.
  • Testing: Added comprehensive complex execution tests (test_batched_dot_complex) to tests/tensor/test_blas.py. Implemented the modern @ matrix multiplication operator in the tests to ensure zero FutureWarning deprecations are triggered in the CI pipeline.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

Resolves pymc-devs#1849. Refactors the BatchedDot C-backend template to accept raw float arrays for alpha and beta, bypassing strict struct type mismatches in cgemm/zgemm. Enables native support for complex64 and complex128. Adds comprehensive tests for complex batched matrix multiplications.
@ayulockedin ayulockedin force-pushed the feature/complex-batched-dot branch from 67aabb0 to e1112b5 Compare February 22, 2026 14:35
@ayulockedin
Copy link
Contributor Author

@jessegrabowski could u take a look at this PR when u have a moment thx :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Batched dot doesn't support complex inputs

1 participant