Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions xls/dslx/ir_convert/function_converter_fuzztest_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1397,5 +1397,118 @@ fn f(o: Outer) -> u32 { o.x }
EXPECT_EQ(a_domain.range().max().bits().data(), std::string{'\004'});
}

TEST(FunctionConverterFuzzTestTest, ConstantLargeRangeDomain) {
ImportData import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(R"(
const R = u32:0..u32:100000;
#[fuzz_test(domains = `R`)]
fn f(x: u32) -> u32 { x }
)",
"test_module.x", "test_module", &import_data));

XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft,
tm.module->GetMemberOrError<FuzzTestFunction>("f"));
ASSERT_NE(ft, nullptr);

Function* f = &ft->fn();

const ConvertOptions convert_options;
PackageConversionData package = MakeConversionData("test_module_package");
PackageData package_data{&package};
FunctionConverter converter(package_data, tm.module, &import_data,
convert_options, /*proc_data=*/nullptr,
/*channel_scope=*/nullptr,
/*is_top=*/true);
XLS_ASSERT_OK(
converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr));

auto* ir_fn =
package_data.conversion_info->package->functions().front().get();

absl::Span<const AttributeData> attributes = ir_fn->attributes();
const AttributeData::Argument& arg = attributes[0].args()[0];
const auto& skv = std::get<AttributeData::StringKeyValueArgument>(arg);

xls::PackageInterfaceProto::Function function_proto;
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(skv.second, &function_proto));
ASSERT_EQ(function_proto.parameter_domains_size(), 1);
const auto& domain = function_proto.parameter_domains(0);
ASSERT_TRUE(domain.has_range());

EXPECT_EQ(domain.range().min().bits().bit_count(), 32);
EXPECT_EQ(domain.range().min().bits().data(),
std::string("\000\000\000\000", 4));

EXPECT_EQ(domain.range().max().bits().bit_count(), 32);
// 99999 in little endian bytes: 9F 86 01 00
char expected_max_bytes[] = {static_cast<char>(0x9F), static_cast<char>(0x86),
0x01, 0x00};
EXPECT_EQ(domain.range().max().bits().data(),
std::string(expected_max_bytes, 4));
}

TEST(FunctionConverterFuzzTestTest, ImportedConstantLargeRangeDomain) {
ImportData import_data = CreateImportDataForTest();

// Parse imported module first.
XLS_ASSERT_OK(ParseAndTypecheck(R"(
pub const R = u32:0..u32:100000;
)",
"imported.x", "imported", &import_data)
.status());

// Parse main module that imports 'imported'.
XLS_ASSERT_OK_AND_ASSIGN(
TypecheckedModule tm,
ParseAndTypecheck(R"(
import imported;
#[fuzz_test(domains = `imported::R`)]
fn f(x: u32) -> u32 { x }
)",
"test_module.x", "test_module", &import_data));

XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft,
tm.module->GetMemberOrError<FuzzTestFunction>("f"));
ASSERT_NE(ft, nullptr);

Function* f = &ft->fn();

const ConvertOptions convert_options;
PackageConversionData package = MakeConversionData("test_module_package");
PackageData package_data{&package};
FunctionConverter converter(package_data, tm.module, &import_data,
convert_options, /*proc_data=*/nullptr,
/*channel_scope=*/nullptr,
/*is_top=*/true);
XLS_ASSERT_OK(
converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr));

auto* ir_fn =
package_data.conversion_info->package->functions().front().get();

absl::Span<const AttributeData> attributes = ir_fn->attributes();
const AttributeData::Argument& arg = attributes[0].args()[0];
const auto& skv = std::get<AttributeData::StringKeyValueArgument>(arg);

xls::PackageInterfaceProto::Function function_proto;
ASSERT_TRUE(google::protobuf::TextFormat::ParseFromString(skv.second, &function_proto));
ASSERT_EQ(function_proto.parameter_domains_size(), 1);
const auto& domain = function_proto.parameter_domains(0);
ASSERT_TRUE(domain.has_range());

EXPECT_EQ(domain.range().min().bits().bit_count(), 32);
EXPECT_EQ(domain.range().min().bits().data(),
std::string("\000\000\000\000", 4));

EXPECT_EQ(domain.range().max().bits().bit_count(), 32);
// 99999 in little endian bytes: 9F 86 01 00
char expected_max_bytes[] = {static_cast<char>(0x9F), static_cast<char>(0x86),
0x01, 0x00};
EXPECT_EQ(domain.range().max().bits().data(),
std::string(expected_max_bytes, 4));
}

} // namespace
} // namespace xls::dslx
26 changes: 17 additions & 9 deletions xls/dslx/ir_convert/fuzz_test_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "xls/dslx/frontend/ast_node.h"
#include "xls/dslx/interp_value.h"
#include "xls/dslx/type_system/type.h"
#include "xls/dslx/type_system/type_info.h"
#include "xls/ir/value.h"
#include "xls/ir/xls_ir_interface.pb.h"

Expand Down Expand Up @@ -204,40 +205,47 @@ absl::Status FuzzTestConverter::LowerConstant(
absl::Status FuzzTestConverter::LowerDomainExpr(
const Type* param_type, const Expr* expr,
PackageInterfaceProto::FuzzTestDomain& proto) {
if (expr->kind() == AstNodeKind::kStructInstance) {
ResolvedDomain resolved = ResolveDomainExpression(expr, current_type_info_);

const Expr* resolved_expr = resolved.expr;
TypeInfo* resolved_type_info = resolved.type_info != nullptr
? const_cast<TypeInfo*>(resolved.type_info)
: current_type_info_;
if (resolved_expr->kind() == AstNodeKind::kStructInstance) {
XLS_RET_CHECK(param_type != nullptr && param_type->IsStruct());
const StructInstance* struct_domain =
absl::down_cast<const StructInstance*>(expr);
const StructType& struct_type = param_type->AsStruct();
return LowerStructInstanceDomain(struct_type, *struct_domain, proto);
}

if (expr->kind() == AstNodeKind::kRange) {
if (resolved_expr->kind() == AstNodeKind::kRange) {
// Ranges get expanded into arrays by the constexpr evaluator, so if you
// have a range of u32:0..u32:FFFFFFFF, it will try to turn it into an array
// of 2^32 elements, which fills memory. So for ranges we perform the
// lowering directly from the AST, without turning to InterpValue.
const Range* range_node = absl::down_cast<const Range*>(expr);
return LowerRangeExpr(range_node, proto);
const Range* range_node = absl::down_cast<const Range*>(resolved_expr);
return LowerRangeExpr(range_node, resolved_type_info, proto);
}
XLS_ASSIGN_OR_RETURN(
InterpValue const_value,
ConstexprEvaluator::EvaluateToValue(import_data_, current_type_info_,
ConstexprEvaluator::EvaluateToValue(import_data_, resolved_type_info,
/*warning_collector=*/nullptr,
/*bindings=*/{}, expr));
/*bindings=*/{}, resolved_expr));
return LowerConstant(param_type, const_value, proto);
}

absl::Status FuzzTestConverter::LowerRangeExpr(
const Range* range_node, PackageInterfaceProto::FuzzTestDomain& proto) {
const Range* range_node, TypeInfo* type_info,
PackageInterfaceProto::FuzzTestDomain& proto) {
XLS_ASSIGN_OR_RETURN(InterpValue min_val,
ConstexprEvaluator::EvaluateToValue(
import_data_, current_type_info_,
import_data_, type_info,
/*warning_collector=*/nullptr,
/*bindings=*/{}, range_node->start()));
XLS_ASSIGN_OR_RETURN(
InterpValue max_val,
ConstexprEvaluator::EvaluateToValue(import_data_, current_type_info_,
ConstexprEvaluator::EvaluateToValue(import_data_, type_info,
/*warning_collector=*/nullptr,
/*bindings=*/{}, range_node->end()));
if (!range_node->inclusive_end()) {
Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/ir_convert/fuzz_test_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class FuzzTestConverter {
const StructType& struct_type, const StructInstance& struct_domain,
PackageInterfaceProto::FuzzTestDomain& proto);

absl::Status LowerRangeExpr(const Range* range_node,
absl::Status LowerRangeExpr(const Range* range_node, TypeInfo* type_info,
PackageInterfaceProto::FuzzTestDomain& proto);
// Main entry point
absl::Status LowerDomainExpr(const Type* param_type, const Expr* expr,
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/type_system/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ cc_library(
"//xls/dslx/frontend:ast_utils",
"//xls/dslx/frontend:module",
"//xls/dslx/frontend:pos",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
42 changes: 42 additions & 0 deletions xls/dslx/type_system/type_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <variant>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
Expand Down Expand Up @@ -1180,4 +1181,45 @@ const FileTable& TypeInfo::file_table() const { return *module_->file_table(); }

FileTable& TypeInfo::file_table() { return *module_->file_table(); }

ResolvedDomain ResolveDomainExpression(const Expr* domain,
const TypeInfo* type_info) {
const TypeInfo* current_ti = type_info;
while (true) {
if (domain->kind() == AstNodeKind::kNameRef) {
const NameRef* name_ref = absl::down_cast<const NameRef*>(domain);
if (std::holds_alternative<const NameDef*>(name_ref->name_def())) {
const NameDef* name_def =
std::get<const NameDef*>(name_ref->name_def());
if (const ConstantDef* const_def =
dynamic_cast<const ConstantDef*>(name_def->definer())) {
domain = const_def->value();
continue;
}
}
}
if (domain->kind() == AstNodeKind::kColonRef && current_ti != nullptr) {
const ColonRef* colon_ref = absl::down_cast<const ColonRef*>(domain);
std::optional<ImportSubject> import_subject =
colon_ref->ResolveImportSubject();
if (import_subject.has_value()) {
std::optional<const ImportedInfo*> imported_info =
current_ti->GetImported(*import_subject);
if (imported_info.has_value()) {
Module* imported_module = (*imported_info)->module;
const TypeInfo* imported_type_info = (*imported_info)->type_info;
auto const_def_or =
imported_module->GetMember<ConstantDef>(colon_ref->attr());
if (const_def_or.has_value()) {
domain = (*const_def_or)->value();
current_ti = imported_type_info;
continue;
}
}
}
}
break;
}
return {domain, current_ti};
}

} // namespace xls::dslx
13 changes: 13 additions & 0 deletions xls/dslx/type_system/type_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,19 @@ inline absl::StatusOr<T*> TypeInfo::GetItemAs(const AstNode* key) const {
return target;
}

struct ResolvedDomain {
const Expr* expr;
const TypeInfo* type_info;
};

// Resolves a domain expression recursively. If the expression is a NameRef that
// references a ConstantDef, it returns the resolved constant value expression.
// If it is a ColonRef that references an imported constant, it resolves it to
// the imported constant value expression using the provided TypeInfo.
// Otherwise, it returns the expression itself.
ResolvedDomain ResolveDomainExpression(const Expr* domain,
const TypeInfo* type_info);

} // namespace xls::dslx

#endif // XLS_DSLX_TYPE_SYSTEM_TYPE_INFO_H_
47 changes: 32 additions & 15 deletions xls/dslx/type_system_v2/validate_concrete_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,8 @@ class TypeValidator : public AstNodeVisitorWithDefault {

absl::Status ValidateStructDomain(const StructInstance* domain,
const Type* param_type,
std::string_view param_str) {
std::string_view param_str,
const TypeInfo* type_info) {
if (!param_type->IsStruct()) {
return TypeInferenceErrorStatus(
domain->span(), param_type,
Expand All @@ -541,15 +542,17 @@ class TypeValidator : public AstNodeVisitorWithDefault {
// ValidateStructInstanceMemberNames, so we can assume they exist here.
XLS_RET_CHECK(it != formal_members.end()) << "Extraneous member " << name;
const Type* formal_member_type = it->second;
XLS_RETURN_IF_ERROR(ValidateFuzzTestDomain(
actual_member, formal_member_type, formal_member_type->ToString()));
XLS_RETURN_IF_ERROR(
ValidateFuzzTestDomain(actual_member, formal_member_type,
formal_member_type->ToString(), type_info));
}
return absl::OkStatus();
}

absl::Status ValidateTupleDomain(const XlsTuple* domain,
const Type* param_type,
std::string_view param_str) {
std::string_view param_str,
const TypeInfo* type_info) {
if (domain->members().empty()) {
// Empty domain for this parameter; this is considered an "Arbitrary"
// domain and always matches.
Expand All @@ -572,8 +575,9 @@ class TypeValidator : public AstNodeVisitorWithDefault {
}
for (int i = 0; i < domain->members().size(); ++i) {
const Type& member_type = tuple_type->GetMemberType(i);
XLS_RETURN_IF_ERROR(ValidateFuzzTestDomain(
domain->members()[i], &member_type, member_type.ToString()));
XLS_RETURN_IF_ERROR(
ValidateFuzzTestDomain(domain->members()[i], &member_type,
member_type.ToString(), type_info));
}
return absl::OkStatus();
}
Expand All @@ -582,19 +586,32 @@ class TypeValidator : public AstNodeVisitorWithDefault {
// function parameter type. Returns an error if not compatible.
absl::Status ValidateFuzzTestDomain(const Expr* domain,
const Type* param_type,
std::string_view param_str) {
if (domain->kind() == AstNodeKind::kStructInstance) {
std::string_view param_str,
const TypeInfo* type_info = nullptr) {
const TypeInfo* current_ti = type_info != nullptr ? type_info : &ti_;

ResolvedDomain resolved = ResolveDomainExpression(domain, current_ti);

const Expr* resolved_domain = resolved.expr;
const TypeInfo* resolved_ti =
resolved.type_info != nullptr ? resolved.type_info : current_ti;

if (resolved_domain->kind() == AstNodeKind::kStructInstance) {
return ValidateStructDomain(
absl::down_cast<const StructInstance*>(domain), param_type,
param_str);
absl::down_cast<const StructInstance*>(resolved_domain), param_type,
param_str, resolved_ti);
}
if (domain->kind() == AstNodeKind::kXlsTuple) {
return ValidateTupleDomain(absl::down_cast<const XlsTuple*>(domain),
param_type, param_str);
if (resolved_domain->kind() == AstNodeKind::kXlsTuple) {
return ValidateTupleDomain(
absl::down_cast<const XlsTuple*>(resolved_domain), param_type,
param_str, resolved_ti);
}

std::optional<Type*> maybe_domain_type = ti_.GetItem(domain);
XLS_RET_CHECK(maybe_domain_type.has_value());
std::optional<Type*> maybe_domain_type =
resolved_ti->GetItem(resolved_domain);
XLS_RET_CHECK(maybe_domain_type.has_value())
<< "No type info for domain: " << resolved_domain->ToString()
<< " in module " << resolved_ti->module()->name();
const Type* domain_type = *maybe_domain_type;

return ValidateFuzzTestDomainType(domain_type, param_type, domain->span(),
Expand Down
Loading
Loading