Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 57 additions & 6 deletions xls/jit/jit_wrapper_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from xls.ir import xls_ir_interface_pb2 as ir_interface_pb2
from xls.ir import xls_type_pb2 as type_pb2
from xls.ir import xls_value_pb2 as value_pb2
from xls.jit import aot_entrypoint_pb2


Expand Down Expand Up @@ -241,7 +242,9 @@ def to_c_type(t: type_pb2.TypeProto) -> Optional[str]:
if c_type is not None:
return c_type
the_type = t.type_enum
if the_type == type_pb2.TypeProto.TUPLE:
if the_type == type_pb2.TypeProto.BITS:
return "xls::Value"
elif the_type == type_pb2.TypeProto.TUPLE:
elems = [to_c_type(e) for e in t.tuple_elements]
if any(e is None for e in elems):
return None
Expand Down Expand Up @@ -312,6 +315,46 @@ def can_use_uint64_range(
return False


def value_proto_to_cpp_literal(
t: type_pb2.TypeProto, v: value_pb2.ValueProto
) -> str:
"""Converts a ValueProto to its C++ representation based on the given TypeProto."""
if v.HasField("bits"):
c_type = to_specialized(t, int_only=True)
val = extract_int_from_bytes(v.bits.data)
if c_type is not None:
return str(val)
# Fallback to big-endian bytes vector construction for wide bits
bytes_str = ", ".join(f"0x{b:02x}" for b in v.bits.data[::-1])
return (
f"xls::Value(xls::Bits::FromBytes(std::vector<uint8_t>{{{bytes_str}}},"
f" {t.bit_count}))"
)

if v.HasField("tuple"):
if len(t.tuple_elements) != len(v.tuple.elements):
raise app.UsageError(
"Tuple element count mismatch for C++ literal conversion"
)
elems = [
value_proto_to_cpp_literal(te, ve)
for te, ve in zip(t.tuple_elements, v.tuple.elements)
]
return f"std::make_tuple({', '.join(elems)})"

if v.HasField("array"):
elems = [
value_proto_to_cpp_literal(t.array_element, ve)
for ve in v.array.elements
]
c_type = to_c_type(t)
return f"{c_type}{{{', '.join(elems)}}}"

raise app.UsageError(
f"Unsupported ValueProto for C++ literal conversion: {v}"
)


def to_domain(
t: type_pb2.TypeProto,
d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain],
Expand Down Expand Up @@ -349,7 +392,14 @@ def to_domain(
if t.type_enum == type_pb2.TypeProto.BITS:
c_type = to_specialized(t, int_only=True)
if c_type is None:
return None
# Support arbitrary domain for wide bits (>64) represented as xls::Value
byte_count = (t.bit_count + 7) // 8
return (
f"fuzztest::Map([](const std::array<uint8_t, {byte_count}>& bytes)"
" { return xls::Value(xls::Bits::FromBytes(bytes,"
f" {t.bit_count})); }},"
f" fuzztest::ArrayOf<{byte_count}>(fuzztest::Arbitrary<uint8_t>()))"
)
if t.bit_count in (8, 16, 32, 64):
return f"fuzztest::Arbitrary<{c_type}>()"
else:
Expand Down Expand Up @@ -378,15 +428,13 @@ def to_domain(
return f"fuzztest::InRange<{cpp_type}>({min_val}, {max_val})"

if d.HasField("element_of"):
c_type = to_specialized(t, int_only=True)
c_type = to_c_type(t)
if c_type is None:
raise app.UsageError(
"ElementOf domain only supported for specializable bits types in"
" this CL"
)
vals = [
str(extract_int_from_bytes(v.bits.data)) for v in d.element_of.values
]
vals = [value_proto_to_cpp_literal(t, v) for v in d.element_of.values]
return f"fuzztest::ElementOf(std::vector<{c_type}>{{{', '.join(vals)}}})"

if d.HasField("tuple"):
Expand All @@ -405,6 +453,9 @@ def to_domain(
def to_value_conversion(t: type_pb2.TypeProto, expr: str) -> str:
"""Generates C++ snippet to convert a native type to xls::Value."""
if t.type_enum == type_pb2.TypeProto.BITS:
c_type = to_specialized(t, int_only=True)
if c_type is None:
return expr
return f"xls::Value(xls::UBits({expr}, {t.bit_count}))"
elif t.type_enum == type_pb2.TypeProto.TUPLE:
elems = []
Expand Down
142 changes: 129 additions & 13 deletions xls/jit/jit_wrapper_generator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from xls.common import runfiles
from xls.ir import xls_ir_interface_pb2 as ir_interface_pb2
from xls.ir import xls_type_pb2 as type_pb2
from xls.ir import xls_value_pb2 as value_pb2
from xls.jit import jit_wrapper_generator


Expand Down Expand Up @@ -134,24 +135,28 @@ def test_nested_tuple(self):
'std::tuple<std::tuple<uint8_t, uint32_t>, uint64_t>',
)

def test_unsupported_bits(self):
def test_wide_bits(self):
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
self.assertIsNone(jit_wrapper_generator.to_c_type(u128))
self.assertEqual(jit_wrapper_generator.to_c_type(u128), 'xls::Value')

def test_unsupported_array(self):
def test_tuple_with_wide_bits(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
a128 = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.ARRAY, array_size=4, array_element=u128
tup = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u128]
)
self.assertEqual(
jit_wrapper_generator.to_c_type(tup), 'std::tuple<uint32_t, xls::Value>'
)
self.assertIsNone(jit_wrapper_generator.to_c_type(a128))

def test_unsupported_tuple(self):
u8 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=8)
def test_wide_array(self):
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
tup = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u8, u128]
a128 = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.ARRAY, array_size=4, array_element=u128
)
self.assertEqual(
jit_wrapper_generator.to_c_type(a128), 'std::array<xls::Value, 4>'
)
self.assertIsNone(jit_wrapper_generator.to_c_type(tup))


class JitWrapperGeneratorWrappedToFuzztestTest(absltest.TestCase):
Expand Down Expand Up @@ -403,7 +408,6 @@ def test_function_unsupported_param_fallback(self):
)
self.assertLen(prop_func.params, 1)
self.assertEqual(prop_func.params[0].cpp_type, 'xls::Value')
self.assertIsNone(prop_func.params[0].conversion_snippet)


class JitWrapperGeneratorToValueConversionTest(absltest.TestCase):
Expand Down Expand Up @@ -784,6 +788,72 @@ def test_render_fuzztest_default_domain(self):
self.assertIn('fuzztest::Arbitrary<uint32_t>()', rendered_code)


class JitWrapperGeneratorValueProtoToCppLiteralTest(absltest.TestCase):

def test_bits_specialized(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
v = value_pb2.ValueProto()
v.bits.bit_count = 32
v.bits.data = b'\x2a\x00\x00\x00' # 42 in little-endian
self.assertEqual(
jit_wrapper_generator.value_proto_to_cpp_literal(u32, v), '42'
)

def test_bits_wide(self):
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
v = value_pb2.ValueProto()
v.bits.bit_count = 128
# 99999 in little endian 16 bytes: 9f 86 01 00 ...
v.bits.data = b'\x9f\x86\x01\x00' + b'\x00' * 12

# Expected big endian hex bytes list: 0x00, ..., 0x01, 0x86, 0x9f
expected_bytes = ', '.join(['0x00'] * 13 + ['0x01', '0x86', '0x9f'])
expected_str = (
f'xls::Value(xls::Bits::FromBytes(std::vector<uint8_t>{{{expected_bytes}}},'
' 128))'
)
self.assertEqual(
jit_wrapper_generator.value_proto_to_cpp_literal(u128, v), expected_str
)

def test_tuple(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
u8 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=8)
tup = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u8]
)

v = value_pb2.ValueProto()
v1 = v.tuple.elements.add()
v1.bits.bit_count = 32
v1.bits.data = b'\x2a\x00\x00\x00' # 42
v2 = v.tuple.elements.add()
v2.bits.bit_count = 8
v2.bits.data = b'\x01' # 1

self.assertEqual(
jit_wrapper_generator.value_proto_to_cpp_literal(tup, v),
'std::make_tuple(42, 1)',
)

def test_array(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
arr = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.ARRAY, array_size=3, array_element=u32
)

v = value_pb2.ValueProto()
for val in (1, 2, 3):
ve = v.array.elements.add()
ve.bits.bit_count = 32
ve.bits.data = val.to_bytes(4, 'little')

self.assertEqual(
jit_wrapper_generator.value_proto_to_cpp_literal(arr, v),
'std::array<uint32_t, 3>{1, 2, 3}',
)


class JitWrapperGeneratorToDomainTest(absltest.TestCase):

def test_extract_int_from_bytes(self):
Expand All @@ -810,7 +880,12 @@ def test_bits_domain_non_power_of_2(self):

def test_bits_domain_too_wide(self):
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
self.assertIsNone(jit_wrapper_generator.to_domain(u128, None))
self.assertEqual(
jit_wrapper_generator.to_domain(u128, None),
'fuzztest::Map([](const std::array<uint8_t, 16>& bytes) { '
'return xls::Value(xls::Bits::FromBytes(bytes, 128)); }, '
'fuzztest::ArrayOf<16>(fuzztest::Arbitrary<uint8_t>()))',
)

def test_range_domain(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
Expand Down Expand Up @@ -850,6 +925,35 @@ def test_element_of_domain(self):
'fuzztest::ElementOf(std::vector<uint32_t>{1, 2})',
)

def test_element_of_tuple_domain(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
tup = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u32]
)
d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
# First element of elements_of: (1, 2)
e1 = d.element_of.values.add()
e1_member0 = e1.tuple.elements.add()
e1_member0.bits.bit_count = 32
e1_member0.bits.data = b'\x01\x00\x00\x00'
e1_member1 = e1.tuple.elements.add()
e1_member1.bits.bit_count = 32
e1_member1.bits.data = b'\x02\x00\x00\x00'
# Second element of elements_of: (3, 4)
e2 = d.element_of.values.add()
e2_member0 = e2.tuple.elements.add()
e2_member0.bits.bit_count = 32
e2_member0.bits.data = b'\x03\x00\x00\x00'
e2_member1 = e2.tuple.elements.add()
e2_member1.bits.bit_count = 32
e2_member1.bits.data = b'\x04\x00\x00\x00'

self.assertEqual(
jit_wrapper_generator.to_domain(tup, d),
'fuzztest::ElementOf(std::vector<std::tuple<uint32_t,'
' uint32_t>>{std::make_tuple(1, 2), std::make_tuple(3, 4)})',
)

def test_tuple_domain(self):
u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
tup = type_pb2.TypeProto(
Expand Down Expand Up @@ -937,6 +1041,18 @@ def test_unsupported_domain_raises(self):
):
jit_wrapper_generator.to_domain(tup, d)

def test_arbitrary_array_of_wide_bits(self):
u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
arr = type_pb2.TypeProto(
type_enum=type_pb2.TypeProto.ARRAY, array_size=3, array_element=u128
)
self.assertEqual(
jit_wrapper_generator.to_domain(arr, None),
'fuzztest::ArrayOf<3>(fuzztest::Map([](const std::array<uint8_t, 16>&'
' bytes) { return xls::Value(xls::Bits::FromBytes(bytes, 128)); },'
' fuzztest::ArrayOf<16>(fuzztest::Arbitrary<uint8_t>())))',
)


class JitWrapperGeneratorToParamTest(absltest.TestCase):

Expand Down
47 changes: 47 additions & 0 deletions xls/tests/fuzz_test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,50 @@ dslx_fuzz_test(
library = ":array_tests_dslx",
test_function = "tuple_with_big_array",
)

dslx_fuzz_test(
name = "array_of_tuples_with_wide_bits_fuzz_test",
library = ":array_tests_dslx",
test_function = "array_of_tuples_with_wide_bits",
)

dslx_fuzz_test(
name = "nested_big_array_fuzz_test",
library = ":array_tests_dslx",
test_function = "nested_big_array",
)

xls_dslx_library(
name = "struct_tests_dslx",
srcs = ["struct_tests.x"],
)

dslx_fuzz_test(
name = "arbitrary_struct_fuzz_test",
library = ":struct_tests_dslx",
test_function = "arbitrary_struct",
)

dslx_fuzz_test(
name = "struct_range_fuzz_test",
library = ":struct_tests_dslx",
test_function = "struct_range",
)

dslx_fuzz_test(
name = "struct_element_of_fuzz_test",
library = ":struct_tests_dslx",
test_function = "struct_element_of",
)

dslx_fuzz_test(
name = "struct_arbitrary_field_fuzz_test",
library = ":struct_tests_dslx",
test_function = "struct_arbitrary_field",
)

dslx_fuzz_test(
name = "struct_with_wide_bits_fuzz_test",
library = ":struct_tests_dslx",
test_function = "struct_with_wide_bits",
)
8 changes: 4 additions & 4 deletions xls/tests/fuzz_test/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ information.
The following domain specifications are supported in the `domains` argument of
the `#[fuzz_test]` attribute:

- **Arbitrary**: `()` - Explores the full range of the type.
- **Arbitrary**: `()` - Explores the full range of the type. For array
parameters, explores the full range of the base type of the array,
constrained to the length of the array.
- **Numeric Range**: `Type:min..max` - Explores values in the range `[min,
max)`. Example: `u32:0..100`. "End-inclusive" ranges work too, e.g.,
`Type:min..=max` explores values in the range `[min, max]`.
- **Element Of**: `[val1, val2, ...]` - Explores only the listed values.
Example: `[u32:5, 10, 15]`.
- **Tuples**: `(Domain1, Domain2, ...)` - For tuple parameters. Example:
`(u32:0..10, [u8:1, 2])`. The parentheses are required.
- **Structs**: `StructName { field_name: domain }` - For struct parameters.

## Known Limitations

- **Array parameters**: are not yet supported.
- **Struct parameters**: (e.g., `StructName { field: Domain, ... }`) are not
yet supported.
- Fuzzing is currently limited to types that can be mapped to native C++ types
up to 64 bits for full specialization. Larger types fallback to `xls::Value`
and may have limited mutation capabilities.
10 changes: 10 additions & 0 deletions xls/tests/fuzz_test/array_tests.x
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,13 @@ fn big_array(x: uN[128][3]) -> bool {
fn tuple_with_big_array(x: (uN[128][2], u32)) -> bool {
true
}

#[fuzz_test(domains=`()`)]
fn array_of_tuples_with_wide_bits(x: (uN[128], u32)[2]) -> bool {
true
}

#[fuzz_test(domains=`()`)]
fn nested_big_array(x: uN[128][2][3]) -> bool {
true
}
Loading
Loading