Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xls/dslx/frontend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,9 @@ cc_library(
":token",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/dslx:import_data",
"//xls/dslx/type_system_v2:import_utils",
"//xls/dslx/type_system_v2:type_annotation_utils",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down
70 changes: 50 additions & 20 deletions xls/dslx/frontend/lambda_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
#include "xls/dslx/frontend/module.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/frontend/token.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/type_system_v2/import_utils.h"
#include "xls/dslx/type_system_v2/type_annotation_utils.h"

namespace xls::dslx {
namespace {
Expand Down Expand Up @@ -172,8 +175,8 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {

class LambdaRewriter : public AstNodeVisitorWithDefault {
public:
explicit LambdaRewriter(const FileTable& file_table)
: file_table_(file_table) {}
explicit LambdaRewriter(const ImportData& import_data)
: import_data_(import_data) {}

absl::Status HandleLambda(const Lambda* node) override {
XLS_RETURN_IF_ERROR(DefaultHandler(node));
Expand Down Expand Up @@ -237,11 +240,11 @@ class LambdaRewriter : public AstNodeVisitorWithDefault {
}
}

NameDef* struct_nd =
module->Make<NameDef>(span,
absl::Substitute("lambda_capture_struct_at_$0",
span.ToString(file_table_)),
/*definer=*/nullptr);
NameDef* struct_nd = module->Make<NameDef>(
span,
absl::Substitute("lambda_capture_struct_at_$0",
span.ToString(import_data_.file_table())),
/*definer=*/nullptr);
StructDef* full_struct_def =
module->Make<StructDef>(span, struct_nd, struct_parametric_bindings,
struct_members, /*is_public=*/false);
Expand Down Expand Up @@ -428,20 +431,47 @@ class LambdaRewriter : public AstNodeVisitorWithDefault {

XLS_ASSIGN_OR_RETURN(TypeDefinition type_def,
ToTypeDefinition(original_nd->definer()));
TypeRef* lambda_type_ref =
TypeRef* instance_type_ref =
module->Make<TypeRef>(original_nd->span(), type_def);

struct_instance_parametrics.push_back(module->Make<TypeRefTypeAnnotation>(
original_nd->span(), lambda_type_ref, std::vector<ExprOrType>{}));
original_nd->span(), instance_type_ref, std::vector<ExprOrType>{}));
parametric_nds.insert(original_nd);

TypeRef* lambda_type_ref = nullptr;
for (const TypeRefTypeAnnotation* original_type_ref : trtas) {
node_replacements.emplace(
original_type_ref,
module->Make<TypeVariableTypeAnnotation>(
module->Make<NameRef>(original_type_ref->span(),
lambda_struct_nd->identifier(),
lambda_struct_nd),
/*internal=*/true));
XLS_ASSIGN_OR_RETURN(std::optional<StructOrProcRef> struct_or_proc_ref,
GetStructOrProcRef(original_type_ref, import_data_));
TypeAnnotation* replacement;
if (struct_or_proc_ref.has_value()) {
if (lambda_type_ref == nullptr) {
XLS_ASSIGN_OR_RETURN(
AstNode * lambda_type_def,
CloneAst(ToAstNode(type_def),
[lambda_struct_nd](
const AstNode* node, const Module*,
const absl::flat_hash_map<const AstNode*, AstNode*>&)
-> std::optional<AstNode*> {
if (node->kind() == AstNodeKind::kNameDef) {
return lambda_struct_nd;
}
return const_cast<AstNode*>(node);
}));

lambda_type_ref = module->Make<TypeRef>(
original_nd->span(), *ToTypeDefinition(lambda_type_def));
}
replacement = module->Make<TypeRefTypeAnnotation>(
original_type_ref->span(), lambda_type_ref,
original_type_ref->parametrics(),
original_type_ref->instantiator());
} else {
replacement = module->Make<TypeVariableTypeAnnotation>(
module->Make<NameRef>(original_type_ref->span(),
lambda_struct_nd->identifier(),
lambda_struct_nd),
/*internal=*/true);
}
node_replacements.emplace(original_type_ref, replacement);
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -489,13 +519,13 @@ class LambdaRewriter : public AstNodeVisitorWithDefault {
seen.insert(original_name_def);
}

const FileTable& file_table_;
const ImportData& import_data_;
};

} // namespace

absl::Status RewriteLambdas(Module& module, const FileTable& file_table) {
LambdaRewriter visitor(file_table);
absl::Status RewriteLambdas(Module& module, const ImportData& import_data) {
LambdaRewriter visitor(import_data);
return module.Accept(&visitor);
}

Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/frontend/lambda_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

#include "absl/status/status.h"
#include "xls/dslx/frontend/module.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/import_data.h"

namespace xls::dslx {

Expand All @@ -40,7 +40,7 @@ namespace xls::dslx {
// let x = u32:2;
// map(arr, lambda_capture{x: x}.call)
// }
absl::Status RewriteLambdas(Module& module, const FileTable& file_table);
absl::Status RewriteLambdas(Module& module, const ImportData& import_data);

} // namespace xls::dslx

Expand Down
2 changes: 1 addition & 1 deletion xls/dslx/frontend/semantics_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ absl::Status SemanticsAnalysis::RunPreTypeCheckPass(
ProcStateVisitor state_visitor(import_data, state_struct_def);
XLS_RETURN_IF_ERROR(module.Accept(&state_visitor));
}
XLS_RETURN_IF_ERROR(RewriteLambdas(module, import_data.file_table()));
XLS_RETURN_IF_ERROR(RewriteLambdas(module, import_data));

AddSpawnTraitToProcDefs add_spawn_trait;
XLS_RETURN_IF_ERROR(module.Accept(&add_spawn_trait));
Expand Down
8 changes: 4 additions & 4 deletions xls/dslx/type_system_v2/populate_table_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2066,10 +2066,10 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor,
for (const StructMemberNode* member : node->members()) {
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(member, member->type()));
}
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
node,
CreateStructOrProcAnnotation(
module_, const_cast<StructDefBase*>(node), {}, std::nullopt)));
XLS_RETURN_IF_ERROR(table_.SetTypeAnnotation(
node,
CreateStructOrProcAnnotation(module_, const_cast<StructDefBase*>(node),
{}, std::nullopt)));

return DefaultHandler(node);
}
Expand Down
38 changes: 36 additions & 2 deletions xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,7 @@ const_assert!(ONE == [[u16:1, 1, 1, 1, 1],
HasNodeWithType("TWO", "uN[32][5][4]"))));
}

// TODO: erinzmoore - Support local struct types.
TEST(TypecheckV2Test, DISABLED_LambdaUsesLocalStructType) {
TEST(TypecheckV2Test, LambdaUsesLocalStructType) {
EXPECT_THAT(
R"(
struct S<N: u32> {
Expand All @@ -754,5 +753,40 @@ const_assert!(RES[0] == S<8>{x: 0});
TypecheckSucceeds(HasNodeWithType("RES", "S { x: uN[8] }[5]")));
}

TEST(TypecheckV2Test, LambdaUsesLocalStructTypeExplicitReturn) {
EXPECT_THAT(
R"(
struct S<N: u32> {
x: uN[N]
}

fn main() -> S<8>[5] {
type MyS = S<8>;
map(u8:0..5, |i| -> MyS { MyS{x: i} })
}

const RES = main();
const_assert!(RES[0] == S<8>{x: 0});
)",
TypecheckSucceeds(HasNodeWithType("RES", "S { x: uN[8] }[5]")));
}

TEST(TypecheckV2Test, LambdaUsesLocalStructTypeMismatch) {
EXPECT_THAT(
R"(
struct S<N: u32> {
x: uN[N]
}

fn main() -> S<16>[5] {
type MyS = S<8>;
map(u16:0..5, |i| { MyS{x: i} })
}

const RES = main();
)",
TypecheckFails(HasSizeMismatch("uN[16]", "uN[8]")));
}

} // namespace
} // namespace xls::dslx
Loading