Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xls/dslx/frontend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -737,13 +737,16 @@ cc_library(
hdrs = ["semantics_analysis.h"],
deps = [
":ast",
":ast_cloner",
":ast_node",
":ast_node_visitor_with_default",
":ast_utils",
":bindings",
":module",
":pos",
":token",
":token_utils",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/dslx:import_data",
"//xls/dslx:warning_collector",
Expand Down
2 changes: 0 additions & 2 deletions xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ std::string_view FunctionTagToString(FunctionTag tag) {
return "proc next";
case FunctionTag::kProcInit:
return "proc init";
case FunctionTag::kLambda:
return "lambda";
}
LOG(FATAL) << "Out-of-range function tag: " << static_cast<int>(tag);
}
Expand Down
5 changes: 4 additions & 1 deletion xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -2453,7 +2453,6 @@ enum class FunctionTag : uint8_t {
kProcConfig,
kProcNext,
kProcInit,
kLambda,
};

std::string_view FunctionTagToString(FunctionTag tag);
Expand Down Expand Up @@ -2745,6 +2744,8 @@ class Instantiation : public Expr {
return explicit_parametrics_;
}

void set_callee(Expr* callee) { callee_ = callee; }

protected:
std::string FormatParametrics() const;

Expand Down Expand Up @@ -3394,6 +3395,8 @@ class Impl : public AstNode {

AstNodeKind kind() const override { return AstNodeKind::kImpl; }

static std::string_view GetDebugTypeName() { return "impl"; }

absl::Status Accept(AstNodeVisitor* v) const override {
return v->HandleImpl(this);
}
Expand Down
3 changes: 2 additions & 1 deletion xls/dslx/frontend/ast_cloner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ fn main() -> u32 {

TEST(AstClonerTest, Lambda) {
constexpr std::string_view kProgram = R"(fn main() -> u32[10] {
let ARR = map(range(0, 10), |i: u32| -> u32 { 2 * i });
let a = u32:0;
let ARR = map(range(0, 10), |i: u32| -> u32 { a * i });
ARR
})";

Expand Down
29 changes: 29 additions & 0 deletions xls/dslx/frontend/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,35 @@ absl::Status Module::AddTop(ModuleMember member,
return absl::OkStatus();
}

absl::Status Module::InsertTop(ModuleMember member,
const MakeCollisionError& make_collision_error) {
std::vector<std::string> member_names = GetMemberNames(member);

for (const std::string& member_name : member_names) {
if (top_by_name_.contains(member_name)) {
const AstNode* node = ToAstNode(top_by_name_.at(member_name));
const Span existing_span = node->GetSpan().value();
const AstNode* new_node = ToAstNode(member);
const Span new_span = new_node->GetSpan().value();
if (make_collision_error != nullptr) {
return make_collision_error(name_, member_name, existing_span, node,
new_span, new_node);
}
return absl::InvalidArgumentError(absl::StrFormat(
"Module %s already contains a member named %s @ %s: %s", name_,
member_name, existing_span.ToString(*file_table_), node->ToString()));
}
}

top_.insert(top_.begin(), member);

top_set_.insert(ToAstNode(member));
for (const std::string& member_name : member_names) {
top_by_name_.insert({member_name, member});
}
return absl::OkStatus();
}

absl::Status Module::InsertTopAfter(
const AstNode* target_member, ModuleMember member,
const MakeCollisionError& make_collision_error) {
Expand Down
6 changes: 6 additions & 0 deletions xls/dslx/frontend/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ class Module : public AstNode {
const AstNode* target_member, ModuleMember member,
const MakeCollisionError& make_collision_error = nullptr);

// Inserts a top-level member as the first member of the module. Behaves like
// `AddTop` with respect to collision checks.
absl::Status InsertTop(
ModuleMember member,
const MakeCollisionError& make_collision_error = nullptr);

// Gets the element in this module with the given target_name, or returns a
// NotFoundError.
template <typename T>
Expand Down
4 changes: 3 additions & 1 deletion xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -339,10 +339,12 @@ absl::StatusOr<Lambda*> Parser::ParseLambda(Bindings& bindings) {
module_->Make<NameDef>(sp, std::string(Lambda::kCallLambdaFn), nullptr);
Function* fn =
module_->Make<Function>(sp, fn_name_def, parametrics, params, return_type,
body, FunctionTag::kLambda,
body, FunctionTag::kNormal,
/*is_public=*/false, /*is_stub=*/false);
fn_name_def->set_definer(fn);

// TODO: erinzmoore - Switch lambda to just wrap a function and don't
// create the struct/impl here.
NameDef* struct_name_def =
module_->Make<NameDef>(sp,
absl::Substitute("lambda_capture_struct_at_$0",
Expand Down
200 changes: 200 additions & 0 deletions xls/dslx/frontend/semantics_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,18 @@
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "absl/strings/substitute.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/dslx/frontend/ast.h"
#include "xls/dslx/frontend/ast_cloner.h"
#include "xls/dslx/frontend/ast_node.h"
#include "xls/dslx/frontend/ast_node_visitor_with_default.h"
#include "xls/dslx/frontend/ast_utils.h"
#include "xls/dslx/frontend/bindings.h"
#include "xls/dslx/frontend/module.h"
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/frontend/token.h"
#include "xls/dslx/frontend/token_utils.h"
#include "xls/dslx/import_data.h"
#include "xls/dslx/type_system/deduce_utils.h"
Expand Down Expand Up @@ -250,6 +254,199 @@ class SideEffectExpressionFinder : public AstNodeVisitorWithDefault {
bool has_side_effect_;
};

class CollectNameRefs : public AstNodeVisitorWithDefault {
public:
absl::Status HandleNameRef(const NameRef* node) override {
name_refs_.insert(node);
return absl::OkStatus();
}

absl::Status DefaultHandler(const AstNode* node) override {
for (const AstNode* child : node->GetChildren(/*want_types=*/false)) {
XLS_RETURN_IF_ERROR(child->Accept(this));
}
return absl::OkStatus();
}

const absl::flat_hash_set<const NameRef*>& name_refs() const {
return name_refs_;
}

private:
absl::flat_hash_set<const NameRef*> name_refs_;
};

class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
public:
ReplaceLambdaWithInvocation(const FileTable& file_table)
: file_table_(file_table) {}

// Converts a lambda into a struct instance and an impl. For example,
//
// fn add_two(arr: u32[5]) -> u32[5] {
// let x = u32:2;
// map(arr, |i: u32| -> u32 { x + i })
// }
//
// becomes:
//
// struct lambda_capture { x: u32 }
//
// impl lambda_capture {
// fn call(self) -> u32 { self.x + i }
// }
//
// fn add_two(arr: u32[5]) -> u32[5] {
// let x = u32:2;
// map(arr, lambda_capture{x: x}.call)
// }
absl::Status HandleLambda(const Lambda* node) override {
Module* module = node->owner();
Function* original_fn = node->function();
Span span = node->span();
CollectNameRefs collect_nr;
XLS_RETURN_IF_ERROR(node->body()->Accept(&collect_nr));
absl::flat_hash_set<const NameDef*> seen;

// For any NameRef in the lambda body, if it references a NameDef outside
// the lambda, add as a member to the struct def associated with the lambda.
std::vector<ParametricBinding*> struct_parametrics;
std::vector<StructMemberNode*> struct_members;
std::vector<std::pair<std::string, Expr*>> struct_instance_members;
for (const NameRef* name_ref : collect_nr.name_refs()) {
if (const NameDef* original_name_def =
std::get<const NameDef*>(name_ref->name_def());
original_name_def != nullptr &&
original_name_def->span().start() < span.start() &&
!seen.contains(original_name_def)) {
// Create parametric binding with generic type to use for the context
// variable type.
GenericTypeAnnotation* gta =
module->Make<GenericTypeAnnotation>(name_ref->span());
NameDef* generic_name_def = module->Make<NameDef>(
name_ref->span(),
absl::Substitute("parametric_type_for_$0",
original_name_def->identifier()),
/*definer=*/gta);
NameRef* generic_name_ref = module->Make<NameRef>(
name_ref->span(), generic_name_def->identifier(), generic_name_def);
struct_parametrics.push_back(module->Make<ParametricBinding>(
generic_name_def, gta, /*expr=*/nullptr));

NameDef* struct_member_nd = module->Make<NameDef>(
name_ref->span(), original_name_def->identifier(),
/*definer=*/nullptr);
TypeVariableTypeAnnotation* tvta =
module->Make<TypeVariableTypeAnnotation>(generic_name_ref,
/*internal=*/true);
StructMemberNode* struct_member = module->Make<StructMemberNode>(
Span::None(), struct_member_nd, Span::None(), tvta);
struct_members.push_back(struct_member);

// Make a name ref that points to the original name def. Add as a member
// to a new struct instance.
NameRef* struct_instance_nr = module->Make<NameRef>(
name_ref->span(), original_name_def->identifier(),
original_name_def);
struct_instance_members.push_back(std::make_pair(
original_name_def->identifier(), struct_instance_nr));
seen.insert(original_name_def);
}
}

NameDef* struct_nd =
module->Make<NameDef>(span,
absl::Substitute("lambda_capture_struct_at_$0",
span.ToString(file_table_)),
/*definer=*/nullptr);
StructDef* full_struct_def =
module->Make<StructDef>(span, struct_nd, struct_parametrics,
struct_members, /*is_public=*/false);
TypeRefTypeAnnotation* struct_type_annotation =
module->Make<TypeRefTypeAnnotation>(
span, module->Make<TypeRef>(span, full_struct_def),
/*parametric=*/std::vector<ExprOrType>());
struct_nd->set_definer(full_struct_def);

StructInstance* struct_instance = module->Make<StructInstance>(
span, struct_type_annotation, struct_instance_members);

Attr* instance_invocation = module->Make<Attr>(
span, struct_instance, std::string(Lambda::kCallLambdaFn));

// For every NameRef in the body, if it references a NameDef that has been
// captured, replace it with a reference to the struct member.
NameDef* self_nd = module->Make<NameDef>(
span, KeywordToString(Keyword::kSelf), /*definer=*/nullptr);
CloneReplacer insert_self =
[self_nd, seen](
const AstNode* node, const Module* _,
const absl::flat_hash_map<const AstNode*, AstNode*>& replacements)
-> std::optional<AstNode*> {
if (node->kind() == AstNodeKind::kNameRef) {
const NameRef* name_ref = absl::down_cast<const NameRef*>(node);
const auto* name_def = std::get<const NameDef*>(name_ref->name_def());
if (name_def != nullptr && seen.contains(name_def)) {
NameRef* self_nr = node->owner()->Make<NameRef>(
name_ref->span(), self_nd->identifier(), self_nd);
return node->owner()->Make<Attr>(name_def->span(), self_nr,
name_def->identifier(),
/* in_parens= */ false);
}
}
return std::nullopt;
};
XLS_ASSIGN_OR_RETURN(
AstNode * cloned_body,
CloneAst(original_fn->body(),
ChainCloneReplacers(&PreserveTypeDefinitionsReplacer,
std::move(insert_self))));
SelfTypeAnnotation* self_type = module->Make<SelfTypeAnnotation>(
span, /*explicit_type=*/false, struct_type_annotation);
std::vector<Param*> params = {module->Make<Param>(self_nd, self_type)};
for (auto* param : original_fn->params()) {
params.push_back(param);
}
Function* impl_fn = module->Make<Function>(
original_fn->span(), original_fn->name_def(),
original_fn->parametric_bindings(), params, original_fn->return_type(),
absl::down_cast<StatementBlock*>(cloned_body), FunctionTag::kNormal,
/*is_public=*/false, /*is_stub=*/false);
Impl* impl = module->Make<Impl>(span, struct_type_annotation,
std::vector<ImplMember>{impl_fn},
/*is_public=*/false);
impl_fn->set_impl(impl);
full_struct_def->set_impl(impl);

// Swap the Lambda in its parent with the attr invocation. After this step,
// Lambdas should no longer appear in the AST.
auto* parent_inv = absl::down_cast<Invocation*>(node->parent());
XLS_RET_CHECK(parent_inv != nullptr);
if (parent_inv->callee() == node) {
parent_inv->set_callee(instance_invocation);
} else {
for (int i = 0; i < parent_inv->args().size(); ++i) {
if (parent_inv->args()[i] == node) {
parent_inv->set_arg(i, instance_invocation);
}
}
}

XLS_RETURN_IF_ERROR(module->InsertTop(full_struct_def));
return module->InsertTopAfter(full_struct_def, impl);
}

absl::Status DefaultHandler(const AstNode* node) override {
for (const AstNode* child : node->GetChildren(/*want_types=*/false)) {
XLS_RETURN_IF_ERROR(child->Accept(this));
}
return absl::OkStatus();
}

private:
const FileTable& file_table_;
};

class PreTypecheckPass : public AstNodeVisitorWithDefault {
public:
PreTypecheckPass(WarningCollector& warning_collector,
Expand Down Expand Up @@ -526,6 +723,9 @@ absl::Status SemanticsAnalysis::RunPreTypeCheckPass(
NextParamStateVisitor next_param_visitor(state_struct_def);
XLS_RETURN_IF_ERROR(module.Accept(&next_param_visitor));
}
ReplaceLambdaWithInvocation lambda_pass(import_data.file_table());
XLS_RETURN_IF_ERROR(module.Accept(&lambda_pass));

PreTypecheckPass pass(warning_collector, import_data.file_table());

for (const ModuleMember& top : module.top()) {
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/type_system/type_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,8 @@ absl::StatusOr<TypeInfo*> TypeInfoOwner::GetRootTypeInfo(const Module* module) {

void TypeInfo::NoteConstExpr(const AstNode* const_expr, InterpValue value) {
VLOG(5) << absl::StreamFormat(
"noting node: `%s` (%p) has constexpr value: `%s`",
const_expr->ToString(), const_expr, value.ToString());
"noting node: `%s` (%p) has constexpr value: `%s` in TypeInfo %p",
const_expr->ToString(), const_expr, value.ToString(), this);

// Note: this assertion will generally hold as of 2024-08-23, except in the
// case of `ConstFor` nodes, which https://github.com/richmckeever is
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/type_system_v2/constant_collector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ class Visitor : public AstNodeVisitorWithDefault {
ti_->NoteConstExpr(constant_def, *value);
ti_->NoteConstExpr(constant_def->value(), *value);
ti_->NoteConstExpr(constant_def->name_def(), *value);
} else {
VLOG(6) << "Constant def: " << constant_def->ToString()
<< " failed constexpr eval: " << value.status();
}
return absl::OkStatus();
}
Expand Down Expand Up @@ -366,6 +369,7 @@ class Visitor : public AstNodeVisitorWithDefault {
&import_data_, ti_, &warning_collector_,
table_.GetParametricEnv(parametric_context_), node->arg());
if (!value.ok()) {
VLOG(6) << "Failed to const_assert " << value.status();
return NotConstantErrorStatus(node->span(), node->arg(), file_table_);
}
VLOG(6) << "Evaluated const assert: " << node->arg()->ToString()
Expand Down
4 changes: 0 additions & 4 deletions xls/dslx/type_system_v2/function_resolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,6 @@ class FunctionResolverImpl : public FunctionResolver {
if (target.has_value()) {
function_node = *target;
}
} else if (callee->kind() == AstNodeKind::kLambda) {
function_node = absl::down_cast<const Lambda*>(callee)->function();
XLS_RETURN_IF_ERROR(converter_.ConvertSubtree(
function_node, caller_function, caller_context));
} else if (callee->kind() == AstNodeKind::kNameRef) {
// Either a local function or a built-in function call.
const auto* name_ref = absl::down_cast<const NameRef*>(callee);
Expand Down
Loading
Loading