From ee46e31576e9fdcef4a8c60736af7f8c02f2fbb1 Mon Sep 17 00:00:00 2001
From: Erin Moore
Date: Fri, 24 Apr 2026 11:16:59 -0700
Subject: [PATCH] Add local constants referenced in a lambda as constant
members in the capture impl.
PiperOrigin-RevId: 905129759
---
xls/dslx/frontend/semantics_analysis.cc | 94 ++++++++++++----
.../typecheck_module_v2_lambda_test.cc | 100 +++++++++++++++++-
2 files changed, 173 insertions(+), 21 deletions(-)
diff --git a/xls/dslx/frontend/semantics_analysis.cc b/xls/dslx/frontend/semantics_analysis.cc
index 19abfda5fe..f1f368cdeb 100644
--- a/xls/dslx/frontend/semantics_analysis.cc
+++ b/xls/dslx/frontend/semantics_analysis.cc
@@ -17,6 +17,7 @@
#include
#include
#include
+#include
#include
#include
#include
@@ -267,7 +268,8 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
}
if (node->GetDefiner() == nullptr ||
(node->GetDefiner()->kind() != AstNodeKind::kFunction &&
- node->GetDefiner()->kind() != AstNodeKind::kImport)) {
+ node->GetDefiner()->kind() != AstNodeKind::kImport &&
+ node->GetDefiner()->kind() != AstNodeKind::kConstantDef)) {
XLS_RETURN_IF_ERROR(AddNameRef(node));
if (node->GetDefiner() != nullptr && in_type_annotation_) {
XLS_RETURN_IF_ERROR(node->GetDefiner()->Accept(this));
@@ -299,16 +301,17 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
return it->second.name_refs;
}
+ absl::flat_hash_set ConstsDefinedPrior(
+ const Pos start) const {
+ // Collected names defined before `start` for which is_const() is true.
+ return NameDefsDefinedPriorInternal(start, is_const);
+ }
+
absl::flat_hash_set NameDefsDefinedPrior(
const Pos start) const {
- absl::flat_hash_set result;
- for (const auto& [name_def, info] : name_ref_info_) {
- if (!info.any_used_in_type_annotation &&
- name_def->span().start() < start) {
- result.insert(name_def);
- }
- }
- return result;
+ // Excludes `const` locals; ConstsDefinedPrior returns those instead.
+ return NameDefsDefinedPriorInternal(
+ start, [](const NameDef* name_def) { return !is_const(name_def); });
}
private:
@@ -333,6 +336,30 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
return absl::OkStatus();
}
+ static bool is_const(const NameDef* name_def) {
+ auto* definer = name_def->definer();
+ const bool is_let =
+ definer != nullptr && definer->kind() == AstNodeKind::kLet;
+ // Only names bound by a Let can be const; any other definer yields false.
+ return is_let && absl::down_cast(definer)->is_const();
+ }
+ }
+
+ absl::flat_hash_set NameDefsDefinedPriorInternal(
+ const Pos start,
+ std::function name_def_filter) const {
+ // Names before `start`, unused in type annotations, passing the filter.
+ absl::flat_hash_set matches;
+ for (const auto& [candidate, info] : name_ref_info_) {
+ if (info.any_used_in_type_annotation ||
+ !(candidate->span().start() < start) || !name_def_filter(candidate)) {
+ continue;
+ }
+ matches.insert(candidate);
+ }
+ return matches;
+ }
+
absl::flat_hash_map name_ref_info_;
bool in_type_annotation_ = false;
};
@@ -368,7 +395,6 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
Span span = node->span();
CollectNameRefs collect_nr;
XLS_RETURN_IF_ERROR(node->body()->Accept(&collect_nr));
- absl::flat_hash_set seen;
// If there are any parametric bindings in the containing function that are
// referenced in the lambda, they should be added as parametric bindings to
@@ -410,11 +436,33 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
}
}
+ // Any constant that is referenced by the lambda should be added as an impl
+ // member.
+ std::vector impl_members;
+ absl::flat_hash_set impl_constants;
+ for (const NameDef* constant_nd :
+ collect_nr.ConstsDefinedPrior(span.start())) {
+ if (parametric_nds.contains(constant_nd)) {
+ continue;
+ }
+ XLS_RET_CHECK(constant_nd->definer() != nullptr &&
+ constant_nd->definer()->kind() == AstNodeKind::kLet);
+ auto let = absl::down_cast(constant_nd->definer());
+ NameDef* new_nd = module->Make(
+ constant_nd->span(), constant_nd->identifier(), /*definer=*/nullptr);
+ ConstantDef* constant_def = module->Make(
+ constant_nd->span(), new_nd, let->type_annotation(), let->rhs(),
+ /*is_public=*/false);
+ impl_members.push_back(constant_def);
+ impl_constants.insert(constant_nd);
+ }
+
// For any NameDef that is referenced in the lambda, but defined prior to
// the lambda, it must be captured in the struct instance, unless it was
// already added as a parametric binding.
std::vector struct_members;
std::vector> struct_instance_members;
+ absl::flat_hash_set self_members;
for (const NameDef* original_name_def :
collect_nr.NameDefsDefinedPrior(span.start())) {
if (!parametric_nds.contains(original_name_def)) {
@@ -450,7 +498,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
original_name_def);
struct_instance_members.push_back(std::make_pair(
original_name_def->identifier(), struct_instance_nr));
- seen.insert(original_name_def);
+ self_members.insert(original_name_def);
}
}
@@ -479,7 +527,8 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
NameDef* self_nd = module->Make(
span, KeywordToString(Keyword::kSelf), /*definer=*/nullptr);
CloneReplacer insert_self =
- [self_nd, seen, name_ref_replacements](
+ [self_nd, self_members, name_ref_replacements, impl_constants,
+ struct_type_annotation](
const AstNode* node, const Module* _,
const absl::flat_hash_map& replacements)
-> std::optional {
@@ -489,12 +538,15 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
return std::nullopt;
}
const auto* name_def = std::get(name_ref->name_def());
- if (name_def != nullptr && seen.contains(name_def)) {
+ if (name_def != nullptr && self_members.contains(name_def)) {
NameRef* self_nr = node->owner()->Make(
name_ref->span(), self_nd->identifier(), self_nd);
return node->owner()->Make(name_def->span(), self_nr,
name_def->identifier(),
/* in_parens= */ false);
+ } else if (name_def != nullptr && impl_constants.contains(name_def)) {
+ return node->owner()->Make(
+ name_def->span(), struct_type_annotation, name_def->identifier());
} else if (name_ref_replacements.contains(name_ref)) {
return name_ref_replacements.at(name_ref);
}
@@ -502,24 +554,26 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
return std::nullopt;
};
XLS_ASSIGN_OR_RETURN(
- AstNode * cloned_body,
- CloneAst(original_fn->body(),
+ AstNode * clone,
+ CloneAst(original_fn,
ChainCloneReplacers(&PreserveTypeDefinitionsReplacer,
std::move(insert_self))));
+ Function* cloned_function = absl::down_cast(clone);
SelfTypeAnnotation* self_type = module->Make(
span, /*explicit_type=*/false, struct_type_annotation);
std::vector params = {module->Make(self_nd, self_type)};
- for (auto* param : original_fn->params()) {
+ for (auto* param : cloned_function->params()) {
params.push_back(param);
}
Function* impl_fn = module->Make(
- original_fn->span(), original_fn->name_def(),
- original_fn->parametric_bindings(), params, original_fn->return_type(),
- absl::down_cast(cloned_body),
+ cloned_function->span(), cloned_function->name_def(),
+ cloned_function->parametric_bindings(), params,
+ cloned_function->return_type(),
+ absl::down_cast(cloned_function->body()),
FunctionTag::kGeneratedFromLambda,
/*is_public=*/false, /*is_stub=*/false);
- Impl* impl = module->Make(span, struct_type_annotation,
- std::vector{impl_fn},
+ impl_members.push_back(impl_fn);
+ Impl* impl = module->Make(span, struct_type_annotation, impl_members,
/*is_public=*/false);
impl_fn->set_impl(impl);
full_struct_def->set_impl(impl);
diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc
index 0e843a7a9f..66a4c2d287 100644
--- a/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc
+++ b/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc
@@ -249,13 +249,111 @@ const_assert!(RES2 == [u32:0, 3, 2]);
HasNodeWithType("RES2", "uN[32][3]"))));
}
+TEST(TypecheckV2Test, NestedLambdaIteratesOverLocalConst) {
+ EXPECT_THAT(  // Local const X should surface as an impl constant of the capture struct.
+ R"(
+fn nested() -> u1[2][3] {
+ const X = u32:2;
+ const Y = u32:3;
+ map(0..Y, | y_idx: u32 | {
+ map(0..X, | x_idx: u32 | {
+ if (x_idx + y_idx) % 2 == 0 {
+ u1:1
+ } else {
+ u1:0
+ }
+ })
+ })
+}
+
+const RES = nested();
+const EX = [
+ [u1:1, u1:0],
+ [u1:0, u1:1],
+ [u1:1, u1:0],
+];
+const_assert!(RES == EX);
+
+)",
+ TypecheckSucceeds(AllOf(
+ HasNodeWithType("RES", "uN[1][2][3]"),
+ HasNodeWithType("lambda_capture_struct_at_fake.x:7:14-15:5::X",
+ "uN[32]"),
+ HasNodeWithType("lambda_capture_struct_at_fake.x:8:18-14:9",
+ "typeof(lambda_capture_struct_at_fake.x:8:18-14:9 { "
+ "y_idx: uN[32] }"))));
+}
+
+TEST(TypecheckV2Test, NestedLambdaIteratesOverLocalConstWithExplicitReturn) {
+ EXPECT_THAT(  // The outer lambda's explicit return type u1[X] also uses local const X.
+ R"(
+fn nested() -> u1[2][3] {
+ const X = u32:2;
+ const Y = u32:3;
+ map(0..Y, | y_idx: u32 | -> u1[X] {
+ map(0..X, | x_idx: u32 | {
+ if (x_idx + y_idx) % 2 == 0 {
+ u1:1
+ } else {
+ u1:0
+ }
+ })
+ })
+}
+
+const RES = nested();
+const EX = [
+ [u1:1, u1:0],
+ [u1:0, u1:1],
+ [u1:1, u1:0],
+];
+const_assert!(RES == EX);
+
+)",
+ TypecheckSucceeds(
+ AllOf(HasNodeWithType("RES", "uN[1][2][3]"),
+ HasNodeWithType("lambda_capture_struct_at_fake.x:7:14-15:5::X",
+ "uN[32]"))));
+}
+
+TEST(TypecheckV2Test, NestedLambdaIteratesOverGlobalConst) {
+ EXPECT_THAT(  // X and Y are module-level consts here, not function-local ones.
+ R"(
+const X = u32:2;
+const Y = u32:3;
+type Results = u1[X][Y];
+
+fn nested() -> Results {
+ map(0..Y, | y_idx | {
+ map(0..X, | x_idx | {
+ if (x_idx + y_idx) % 2 == 0 {
+ 1
+ } else {
+ 0
+ }
+ })
+ })
+}
+
+const RES = nested();
+const EX = [
+ [u1:1, u1:0],
+ [u1:0, u1:1],
+ [u1:1, u1:0],
+];
+const_assert!(RES == EX);
+
+)",
+ TypecheckSucceeds(HasNodeWithType("RES", "uN[1][2][3]")));
+}
+
TEST(TypecheckV2Test, LambdaUsesUnrollForOutput) {
EXPECT_THAT(
R"(
const A = u32:1;
fn foo() -> u32[5] {
let B = u32:2;
- const X = unroll_for! (i, a) in u32:0..5 {
+ let X = unroll_for! (i, a) in u32:0..5 {
let C = B + i;
let D = A * a;
C + D