Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 74 additions & 20 deletions xls/dslx/frontend/semantics_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -267,7 +268,8 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
}
if (node->GetDefiner() == nullptr ||
(node->GetDefiner()->kind() != AstNodeKind::kFunction &&
node->GetDefiner()->kind() != AstNodeKind::kImport)) {
node->GetDefiner()->kind() != AstNodeKind::kImport &&
node->GetDefiner()->kind() != AstNodeKind::kConstantDef)) {
XLS_RETURN_IF_ERROR(AddNameRef(node));
if (node->GetDefiner() != nullptr && in_type_annotation_) {
XLS_RETURN_IF_ERROR(node->GetDefiner()->Accept(this));
Expand Down Expand Up @@ -299,16 +301,17 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
return it->second.name_refs;
}

// Returns the NameDefs recorded in name_ref_info_ that were introduced by a
// `const` let binding (per is_const), start before `start`, and are not
// referenced from any type annotation. The position and type-annotation
// checks are applied by NameDefsDefinedPriorInternal.
absl::flat_hash_set<const NameDef*> ConstsDefinedPrior(
    const Pos start) const {
  return NameDefsDefinedPriorInternal(start, is_const);
}

// Returns the NameDefs recorded in name_ref_info_ that start before `start`,
// are not referenced from any type annotation, and (in the delegating form
// below) are NOT `const` let bindings; consts are reported separately by
// ConstsDefinedPrior.
// NOTE(review): this diff rendering has no +/- markers, so the pre-change
// body (the explicit loop ending in `return result;`) and the post-change body
// (the call to NameDefsDefinedPriorInternal) appear together. Presumably only
// the delegating body exists in the new file; as shown, the second `return`
// would be unreachable.
absl::flat_hash_set<const NameDef*> NameDefsDefinedPrior(
const Pos start) const {
// NOTE(review): presumably the removed pre-change implementation.
absl::flat_hash_set<const NameDef*> result;
for (const auto& [name_def, info] : name_ref_info_) {
if (!info.any_used_in_type_annotation &&
name_def->span().start() < start) {
result.insert(name_def);
}
}
return result;
// NOTE(review): presumably the added post-change implementation; it keeps
// the same position/type-annotation filtering but excludes const lets.
return NameDefsDefinedPriorInternal(
start, [](const NameDef* name_def) { return !is_const(name_def); });
}

private:
Expand All @@ -333,6 +336,30 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
return absl::OkStatus();
}

static bool is_const(const NameDef* name_def) {
if (name_def->definer() == nullptr ||
name_def->definer()->kind() != AstNodeKind::kLet) {
return false;
}
auto let = absl::down_cast<Let*>(name_def->definer());
return let->is_const();
}

// Shared implementation of ConstsDefinedPrior / NameDefsDefinedPrior.
// Selects the NameDefs recorded in name_ref_info_ that:
//   1. are not referenced from any type annotation,
//   2. have a span starting before `start`, and
//   3. are accepted by `name_def_filter`.
// The filter is consulted only for entries that pass checks 1 and 2.
//
// NOTE(review): the result is an absl::flat_hash_set, whose iteration order
// is unspecified. Callers that build ordered output from it (e.g. capture
// struct members or impl members) may get a nondeterministic ordering;
// consider sorting by span if stable output matters.
absl::flat_hash_set<const NameDef*> NameDefsDefinedPriorInternal(
    const Pos start,
    std::function<bool(const NameDef*)> name_def_filter) const {
  absl::flat_hash_set<const NameDef*> matches;
  for (const auto& entry : name_ref_info_) {
    const NameDef* candidate = entry.first;
    if (entry.second.any_used_in_type_annotation) {
      continue;
    }
    if (!(candidate->span().start() < start)) {
      continue;
    }
    if (name_def_filter(candidate)) {
      matches.insert(candidate);
    }
  }
  return matches;
}

absl::flat_hash_map<const NameDef*, NameRefInfo> name_ref_info_;
bool in_type_annotation_ = false;
};
Expand Down Expand Up @@ -368,7 +395,6 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
Span span = node->span();
CollectNameRefs collect_nr;
XLS_RETURN_IF_ERROR(node->body()->Accept(&collect_nr));
absl::flat_hash_set<const NameDef*> seen;

// If there are any parametric bindings in the containing function that are
// referenced in the lambda, they should be added as parametric bindings to
Expand Down Expand Up @@ -410,11 +436,33 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
}
}

// Any constant that is referenced by the lambda should be added as an impl
// member.
std::vector<ImplMember> impl_members;
absl::flat_hash_set<const NameDef*> impl_constants;
for (const NameDef* constant_nd :
collect_nr.ConstsDefinedPrior(span.start())) {
if (parametric_nds.contains(constant_nd)) {
continue;
}
XLS_RET_CHECK(constant_nd->definer() != nullptr &&
constant_nd->definer()->kind() == AstNodeKind::kLet);
auto let = absl::down_cast<Let*>(constant_nd->definer());
NameDef* new_nd = module->Make<NameDef>(
constant_nd->span(), constant_nd->identifier(), /*definer=*/nullptr);
ConstantDef* constant_def = module->Make<ConstantDef>(
constant_nd->span(), new_nd, let->type_annotation(), let->rhs(),
/*is_public=*/false);
impl_members.push_back(constant_def);
impl_constants.insert(constant_nd);
}

// For any NameDef that is referenced in the lambda, but defined prior to
// the lambda, it must be captured in the struct instance, unless it was
// already added as a parametric binding.
std::vector<StructMemberNode*> struct_members;
std::vector<std::pair<std::string, Expr*>> struct_instance_members;
absl::flat_hash_set<const NameDef*> self_members;
for (const NameDef* original_name_def :
collect_nr.NameDefsDefinedPrior(span.start())) {
if (!parametric_nds.contains(original_name_def)) {
Expand Down Expand Up @@ -450,7 +498,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
original_name_def);
struct_instance_members.push_back(std::make_pair(
original_name_def->identifier(), struct_instance_nr));
seen.insert(original_name_def);
self_members.insert(original_name_def);
}
}

Expand Down Expand Up @@ -479,7 +527,8 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
NameDef* self_nd = module->Make<NameDef>(
span, KeywordToString(Keyword::kSelf), /*definer=*/nullptr);
CloneReplacer insert_self =
[self_nd, seen, name_ref_replacements](
[self_nd, self_members, name_ref_replacements, impl_constants,
struct_type_annotation](
const AstNode* node, const Module* _,
const absl::flat_hash_map<const AstNode*, AstNode*>& replacements)
-> std::optional<AstNode*> {
Expand All @@ -489,37 +538,42 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
return std::nullopt;
}
const auto* name_def = std::get<const NameDef*>(name_ref->name_def());
if (name_def != nullptr && seen.contains(name_def)) {
if (name_def != nullptr && self_members.contains(name_def)) {
NameRef* self_nr = node->owner()->Make<NameRef>(
name_ref->span(), self_nd->identifier(), self_nd);
return node->owner()->Make<Attr>(name_def->span(), self_nr,
name_def->identifier(),
/* in_parens= */ false);
} else if (name_def != nullptr && impl_constants.contains(name_def)) {
return node->owner()->Make<ColonRef>(
name_def->span(), struct_type_annotation, name_def->identifier());
} else if (name_ref_replacements.contains(name_ref)) {
return name_ref_replacements.at(name_ref);
}
}
return std::nullopt;
};
XLS_ASSIGN_OR_RETURN(
AstNode * cloned_body,
CloneAst(original_fn->body(),
AstNode * clone,
CloneAst(original_fn,
ChainCloneReplacers(&PreserveTypeDefinitionsReplacer,
std::move(insert_self))));
Function* cloned_function = absl::down_cast<Function*>(clone);
SelfTypeAnnotation* self_type = module->Make<SelfTypeAnnotation>(
span, /*explicit_type=*/false, struct_type_annotation);
std::vector<Param*> params = {module->Make<Param>(self_nd, self_type)};
for (auto* param : original_fn->params()) {
for (auto* param : cloned_function->params()) {
params.push_back(param);
}
Function* impl_fn = module->Make<Function>(
original_fn->span(), original_fn->name_def(),
original_fn->parametric_bindings(), params, original_fn->return_type(),
absl::down_cast<StatementBlock*>(cloned_body),
cloned_function->span(), cloned_function->name_def(),
cloned_function->parametric_bindings(), params,
cloned_function->return_type(),
absl::down_cast<StatementBlock*>(cloned_function->body()),
FunctionTag::kGeneratedFromLambda,
/*is_public=*/false, /*is_stub=*/false);
Impl* impl = module->Make<Impl>(span, struct_type_annotation,
std::vector<ImplMember>{impl_fn},
impl_members.push_back(impl_fn);
Impl* impl = module->Make<Impl>(span, struct_type_annotation, impl_members,
/*is_public=*/false);
impl_fn->set_impl(impl);
full_struct_def->set_impl(impl);
Expand Down
100 changes: 99 additions & 1 deletion xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,111 @@ const_assert!(RES2 == [u32:0, 3, 2]);
HasNodeWithType("RES2", "uN[32][3]"))));
}

// A lambda nested inside another lambda references the function-local
// `const` bindings X and Y. The lambda rewrite is expected to turn the
// referenced consts into constants on the generated capture struct's impl.
// Therefore the `...::X` node (a ColonRef into the capture struct) should be
// typed uN[32], and the inner capture struct should contain only the captured
// `y_idx` member, not X.
TEST(TypecheckV2Test, NestedLambdaIteratesOverLocalConst) {
EXPECT_THAT(
R"(
fn nested() -> u1[2][3] {
const X = u32:2;
const Y = u32:3;
map(0..Y, | y_idx: u32 | {
map(0..X, | x_idx: u32 | {
if (x_idx + y_idx) % 2 == 0 {
u1:1
} else {
u1:0
}
})
})
}

const RES = nested();
const EX = [
[u1:1, u1:0],
[u1:0, u1:1],
[u1:1, u1:0],
];
const_assert!(RES == EX);

)",
TypecheckSucceeds(AllOf(
HasNodeWithType("RES", "uN[1][2][3]"),
HasNodeWithType("lambda_capture_struct_at_fake.x:7:14-15:5::X",
"uN[32]"),
// NOTE(review): the expected string below opens "typeof(" but has no
// closing ')'; this only matches if HasNodeWithType does substring
// matching -- confirm.
HasNodeWithType("lambda_capture_struct_at_fake.x:8:18-14:9<u32>",
"typeof(lambda_capture_struct_at_fake.x:8:18-14:9 { "
"y_idx: uN[32] }"))));
}

// Same scenario as NestedLambdaIteratesOverLocalConst, but the outer lambda
// declares an explicit return type `u1[X]`. The local const X is therefore
// referenced outside the lambda body (in the signature) as well as inside it.
// The `...::X` impl constant is still expected to be typed uN[32].
TEST(TypecheckV2Test, NestedLambdaIteratesOverLocalConstWithExplicitReturn) {
EXPECT_THAT(
R"(
fn nested() -> u1[2][3] {
const X = u32:2;
const Y = u32:3;
map(0..Y, | y_idx: u32 | -> u1[X] {
map(0..X, | x_idx: u32 | {
if (x_idx + y_idx) % 2 == 0 {
u1:1
} else {
u1:0
}
})
})
}

const RES = nested();
const EX = [
[u1:1, u1:0],
[u1:0, u1:1],
[u1:1, u1:0],
];
const_assert!(RES == EX);

)",
TypecheckSucceeds(
AllOf(HasNodeWithType("RES", "uN[1][2][3]"),
HasNodeWithType("lambda_capture_struct_at_fake.x:7:14-15:5::X",
"uN[32]"))));
}

// Nested lambdas that reference module-level (global) consts X and Y, which
// are also used in a type alias. Presumably these are not treated as captures
// because their definer is a ConstantDef rather than a local Let. The test
// only checks that type checking succeeds and that RES has the expected type.
TEST(TypecheckV2Test, NestedLambdaIteratesOverGlobalConst) {
EXPECT_THAT(
R"(
const X = u32:2;
const Y = u32:3;
type Results = u1[X][Y];

fn nested() -> Results {
map(0..Y, | y_idx | {
map(0..X, | x_idx | {
if (x_idx + y_idx) % 2 == 0 {
1
} else {
0
}
})
})
}

const RES = nested();
const EX = [
[u1:1, u1:0],
[u1:0, u1:1],
[u1:1, u1:0],
];
const_assert!(RES == EX);

)",
TypecheckSucceeds(HasNodeWithType("RES", "uN[1][2][3]")));
}

TEST(TypecheckV2Test, LambdaUsesUnrollForOutput) {
EXPECT_THAT(
R"(
const A = u32:1;
fn foo() -> u32[5] {
let B = u32:2;
const X = unroll_for! (i, a) in u32:0..5 {
let X = unroll_for! (i, a) in u32:0..5 {
let C = B + i;
let D = A * a;
C + D
Expand Down
Loading