Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 67 additions & 39 deletions pytensor/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,10 +1368,9 @@ def perform(self, node, inp, out):

def c_support_code(self, **kwargs):
batch_gemm_defn = """
template<typename dtype>
bool batch_gemm(void (*gemm)(char*, char*, const int*, const int*, const int*, const dtype*, const dtype*, const int*, const dtype*, const int*, const dtype*, dtype*, const int*),
int type_size, PyArrayObject* xs, PyArrayObject* ys,
PyArrayObject* zs) {
template<typename dtype, typename GEMM>
bool batch_gemm(GEMM gemm, int type_size, PyArrayObject* xs, PyArrayObject* ys,
PyArrayObject* zs, const dtype* a, const dtype* b) {
npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs);
npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys);
npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs);
Expand All @@ -1396,53 +1395,50 @@ def c_support_code(self, **kwargs):
return 1;
}

/* encode the stride structure of _x,_y,_z into a single integer. */
int unit = 0;
unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8;
unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4;
unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0;

/* create appropriate strides for malformed matrices that are row or column
* vectors, or empty matrices.
* In that case, the value of the stride does not really matter, but
* some versions of BLAS insist that:
* - they are not smaller than the number of elements in the array,
* - they are not 0.
*/
int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1);
int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1);
int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1);
int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1);
int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1);
int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1);

dtype* x = (dtype*)PyArray_DATA(xs);
dtype* y = (dtype*)PyArray_DATA(ys);
dtype* z = (dtype*)PyArray_DATA(zs);
// Cast to char* to ensure byte-level pointer arithmetic is perfectly safe
char* x_p = (char*)PyArray_DATA(xs);
char* y_p = (char*)PyArray_DATA(ys);
char* z_p = (char*)PyArray_DATA(zs);

dtype a = 1.0;
dtype b = 0.0;
char N = 'N';
char T = 'T';
int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2];

// loop over batch axis
for (int i = 0; i < Nz[0]; i++) {
// Inside the loop, cast back to the numeric type BLAS expects
dtype* x = (dtype*)x_p;
dtype* y = (dtype*)y_p;
dtype* z = (dtype*)z_p;

switch(unit)
{
case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break;
case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break;
case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break;
case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break;
case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break;
case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break;
case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break;
case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break;
case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, a, y, &sy_1, x, &sx_1, b, z, &sz_1); break;
case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, a, y, &sy_1, x, &sx_2, b, z, &sz_1); break;
case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, a, y, &sy_2, x, &sx_1, b, z, &sz_1); break;
case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, a, y, &sy_2, x, &sx_2, b, z, &sz_1); break;
case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, a, x, &sx_1, y, &sy_1, b, z, &sz_2); break;
case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, a, x, &sx_2, y, &sy_1, b, z, &sz_2); break;
case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, a, x, &sx_1, y, &sy_2, b, z, &sz_2); break;
case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, a, x, &sx_2, y, &sy_2, b, z, &sz_2); break;
default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1;
};
x += Sx[0] / type_size;
y += Sy[0] / type_size;
z += Sz[0] / type_size;

// Increment safely using byte strides
x_p += Sx[0];
y_p += Sy[0];
z_p += Sz[0];
}

return 0;
Expand Down Expand Up @@ -1562,16 +1558,22 @@ def contiguous(var, ndim):
{contiguate}

if ((PyArray_DESCR({_x})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR({_x})->type_num != NPY_FLOAT))
{{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}}
&& (PyArray_DESCR({_x})->type_num != NPY_FLOAT)
&& (PyArray_DESCR({_x})->type_num != NPY_CFLOAT)
&& (PyArray_DESCR({_x})->type_num != NPY_CDOUBLE))
{{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not float, double, complex64, or complex128"); {fail};}}

if ((PyArray_DESCR({_y})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR({_y})->type_num != NPY_FLOAT))
{{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}}
&& (PyArray_DESCR({_y})->type_num != NPY_FLOAT)
&& (PyArray_DESCR({_y})->type_num != NPY_CFLOAT)
&& (PyArray_DESCR({_y})->type_num != NPY_CDOUBLE))
{{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not float, double, complex64, or complex128"); {fail};}}

if ((PyArray_DESCR({_z})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR({_z})->type_num != NPY_FLOAT))
{{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}}
&& (PyArray_DESCR({_z})->type_num != NPY_FLOAT)
&& (PyArray_DESCR({_z})->type_num != NPY_CFLOAT)
&& (PyArray_DESCR({_z})->type_num != NPY_CDOUBLE))
{{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not float, double, complex64, or complex128"); {fail};}}

if ((PyArray_DESCR({_x})->type_num != PyArray_DESCR({_y})->type_num)
||(PyArray_DESCR({_x})->type_num != PyArray_DESCR({_z})->type_num))
Expand All @@ -1580,13 +1582,39 @@ def contiguous(var, ndim):
switch (type_num)
{{
case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size, {_x}, {_y}, {_z})) {{
{fail};
{{
float a = 1.0f; float b = 0.0f;
if (batch_gemm<float>(sgemm_, type_size, {_x}, {_y}, {_z}, &a, &b)) {{
{fail};
}}
}}
break;
case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size, {_x}, {_y}, {_z})) {{
{fail};
{{
double a = 1.0; double b = 0.0;
if (batch_gemm<double>(dgemm_, type_size, {_x}, {_y}, {_z}, &a, &b)) {{
{fail};
}}
}}
break;
case NPY_CFLOAT:
{{
// Complex numbers are natively just arrays of floats in C
float a[2] = {{1.0f, 0.0f}};
float b[2] = {{0.0f, 0.0f}};
if (batch_gemm<float>(cgemm_, type_size, {_x}, {_y}, {_z}, a, b)) {{
{fail};
}}
}}
break;
case NPY_CDOUBLE:
{{
// Pass arrays to satisfy zgemm's double* requirement
double a[2] = {{1.0, 0.0}};
double b[2] = {{0.0, 0.0}};
if (batch_gemm<double>(zgemm_, type_size, {_x}, {_y}, {_z}, a, b)) {{
{fail};
}}
}}
break;
}}
Expand Down
43 changes: 43 additions & 0 deletions tests/tensor/test_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2501,6 +2501,49 @@ def test_batched_dot():
assert result.shape[0] == first_mat_val.shape[0]


def test_batched_dot_complex():
    """Check batched matmul of complex 3D tensors against NumPy.

    Builds ``x @ y`` for complex128 (``ztensor3``) and complex64
    (``ctensor3``) inputs, compiles it with ``pytensor.function`` using the
    default compilation mode (no ``mode`` argument is passed, so no specific
    backend is selected here), and compares the result with ``np.matmul``.

    The original note on this test references issue #1849.
    """
    import numpy as np

    import pytensor.tensor as pt
    from pytensor import function

    # Shared random complex128 data; the complex64 case uses a down-cast copy
    # of the same values so both precisions see the same inputs.
    rng = np.random.default_rng(42)
    shape = (2, 6, 6)
    x_val = rng.normal(size=shape) + 1j * rng.normal(size=shape)
    y_val = rng.normal(size=shape) + 1j * rng.normal(size=shape)

    # (symbolic constructor, numpy dtype, rtol, atol) for each precision.
    cases = (
        (pt.ztensor3, np.complex128, 1e-5, 1e-8),
        (pt.ctensor3, np.complex64, 1e-4, 1e-6),
    )

    for make_tensor3, dtype, rtol, atol in cases:
        x = make_tensor3("x")
        y = make_tensor3("y")
        # The @ operator is used instead of the legacy dot helper, which the
        # original test noted emits a FutureWarning.
        f = function([x, y], x @ y)

        x_in = x_val.astype(dtype)
        y_in = y_val.astype(dtype)

        expected = np.matmul(x_in, y_in)
        actual = f(x_in, y_in)

        # A silent up-cast (e.g. complex64 -> complex128) would still pass
        # the tolerance check below, so pin the dtype explicitly.
        assert actual.dtype == expected.dtype
        np.testing.assert_allclose(
            actual,
            expected,
            rtol=rtol,
            atol=atol,
            err_msg=f"batched matmul mismatch for dtype {np.dtype(dtype).name}",
        )


def test_batched_dot_not_contiguous():
def np_genarray(*_shape):
size = 1
Expand Down
Loading