From ee46e31576e9fdcef4a8c60736af7f8c02f2fbb1 Mon Sep 17 00:00:00 2001 From: Erin Moore Date: Fri, 24 Apr 2026 11:16:59 -0700 Subject: [PATCH] Add local constants referenced in a lambda as constant members in the capture impl. PiperOrigin-RevId: 905129759 --- xls/dslx/frontend/semantics_analysis.cc | 94 ++++++++++++---- .../typecheck_module_v2_lambda_test.cc | 100 +++++++++++++++++- 2 files changed, 173 insertions(+), 21 deletions(-) diff --git a/xls/dslx/frontend/semantics_analysis.cc b/xls/dslx/frontend/semantics_analysis.cc index 19abfda5fe..f1f368cdeb 100644 --- a/xls/dslx/frontend/semantics_analysis.cc +++ b/xls/dslx/frontend/semantics_analysis.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -267,7 +268,8 @@ class CollectNameRefs : public AstNodeVisitorWithDefault { } if (node->GetDefiner() == nullptr || (node->GetDefiner()->kind() != AstNodeKind::kFunction && - node->GetDefiner()->kind() != AstNodeKind::kImport)) { + node->GetDefiner()->kind() != AstNodeKind::kImport && + node->GetDefiner()->kind() != AstNodeKind::kConstantDef)) { XLS_RETURN_IF_ERROR(AddNameRef(node)); if (node->GetDefiner() != nullptr && in_type_annotation_) { XLS_RETURN_IF_ERROR(node->GetDefiner()->Accept(this)); @@ -299,16 +301,17 @@ class CollectNameRefs : public AstNodeVisitorWithDefault { return it->second.name_refs; } + absl::flat_hash_set ConstsDefinedPrior( + const Pos start) const { + absl::flat_hash_set result; + return NameDefsDefinedPriorInternal(start, is_const); + } + absl::flat_hash_set NameDefsDefinedPrior( const Pos start) const { absl::flat_hash_set result; - for (const auto& [name_def, info] : name_ref_info_) { - if (!info.any_used_in_type_annotation && - name_def->span().start() < start) { - result.insert(name_def); - } - } - return result; + return NameDefsDefinedPriorInternal( + start, [](const NameDef* name_def) { return !is_const(name_def); }); } private: @@ -333,6 +336,30 @@ class CollectNameRefs : public AstNodeVisitorWithDefault { return absl::OkStatus(); } + static bool is_const(const NameDef* name_def) { + if (name_def->definer() == nullptr || + name_def->definer()->kind() != AstNodeKind::kLet) { + return false; + } + auto let = absl::down_cast(name_def->definer()); + return let->is_const(); + } + + absl::flat_hash_set NameDefsDefinedPriorInternal( + const Pos start, + std::function name_def_filter) const { + absl::flat_hash_set result; + for (const auto& [name_def, info] : name_ref_info_) { + if (!info.any_used_in_type_annotation && + name_def->span().start() < start) { + if (name_def_filter(name_def)) { + result.insert(name_def); + } + } + } + return result; + } + absl::flat_hash_map name_ref_info_; bool in_type_annotation_ = false; }; @@ -368,7 +395,6 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault { Span span = node->span(); CollectNameRefs collect_nr; XLS_RETURN_IF_ERROR(node->body()->Accept(&collect_nr)); - absl::flat_hash_set seen; // If there are any parametric bindings in the containing function that are // referenced in the lambda, they should be added as parametric bindings to @@ -410,11 +436,33 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault { } } + // Any constant that is referenced by the lambda should be added as an impl + // member. + std::vector impl_members; + absl::flat_hash_set impl_constants; + for (const NameDef* constant_nd : + collect_nr.ConstsDefinedPrior(span.start())) { + if (parametric_nds.contains(constant_nd)) { + continue; + } + XLS_RET_CHECK(constant_nd->definer() != nullptr && + constant_nd->definer()->kind() == AstNodeKind::kLet); + auto let = absl::down_cast(constant_nd->definer()); + NameDef* new_nd = module->Make( + constant_nd->span(), constant_nd->identifier(), /*definer=*/nullptr); + ConstantDef* constant_def = module->Make( + constant_nd->span(), new_nd, let->type_annotation(), let->rhs(), + /*is_public=*/false); + impl_members.push_back(constant_def); + impl_constants.insert(constant_nd); + } + // For any NameDef that is referenced in the lambda, but defined prior to // the lambda, it must be captured in the struct instance, unless it was // already added as a parametric binding. std::vector struct_members; std::vector> struct_instance_members; + absl::flat_hash_set self_members; for (const NameDef* original_name_def : collect_nr.NameDefsDefinedPrior(span.start())) { if (!parametric_nds.contains(original_name_def)) { @@ -450,7 +498,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault { original_name_def); struct_instance_members.push_back(std::make_pair( original_name_def->identifier(), struct_instance_nr)); - seen.insert(original_name_def); + self_members.insert(original_name_def); } } @@ -479,7 +527,8 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault { NameDef* self_nd = module->Make( span, KeywordToString(Keyword::kSelf), /*definer=*/nullptr); CloneReplacer insert_self = - [self_nd, seen, name_ref_replacements]( + [self_nd, self_members, name_ref_replacements, impl_constants, + struct_type_annotation]( const AstNode* node, const Module* _, const absl::flat_hash_map& replacements) -> std::optional { @@ -489,12 +538,15 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault { return std::nullopt; } const auto* name_def = std::get(name_ref->name_def()); - if (name_def != nullptr && seen.contains(name_def)) { + if (name_def != nullptr && self_members.contains(name_def)) { NameRef* self_nr = node->owner()->Make( name_ref->span(), self_nd->identifier(), self_nd); return node->owner()->Make(name_def->span(), self_nr, name_def->identifier(), /* in_parens= */ false); + } else if (name_def != nullptr && impl_constants.contains(name_def)) { + return node->owner()->Make( + name_def->span(), struct_type_annotation, name_def->identifier()); } else if (name_ref_replacements.contains(name_ref)) { return name_ref_replacements.at(name_ref); } @@ -502,24 +554,26 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault { return std::nullopt; }; XLS_ASSIGN_OR_RETURN( - AstNode * cloned_body, - CloneAst(original_fn->body(), + AstNode * clone, + CloneAst(original_fn, ChainCloneReplacers(&PreserveTypeDefinitionsReplacer, std::move(insert_self)))); + Function* cloned_function = absl::down_cast(clone); SelfTypeAnnotation* self_type = module->Make( span, /*explicit_type=*/false, struct_type_annotation); std::vector params = {module->Make(self_nd, self_type)}; - for (auto* param : original_fn->params()) { + for (auto* param : cloned_function->params()) { params.push_back(param); } Function* impl_fn = module->Make( - original_fn->span(), original_fn->name_def(), - original_fn->parametric_bindings(), params, original_fn->return_type(), - absl::down_cast(cloned_body), + cloned_function->span(), cloned_function->name_def(), + cloned_function->parametric_bindings(), params, + cloned_function->return_type(), + absl::down_cast(cloned_function->body()), FunctionTag::kGeneratedFromLambda, /*is_public=*/false, /*is_stub=*/false); - Impl* impl = module->Make(span, struct_type_annotation, - std::vector{impl_fn}, + impl_members.push_back(impl_fn); + Impl* impl = module->Make(span, struct_type_annotation, impl_members, /*is_public=*/false); impl_fn->set_impl(impl); full_struct_def->set_impl(impl); diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc index 0e843a7a9f..66a4c2d287 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_lambda_test.cc @@ -249,13 +249,111 @@ const_assert!(RES2 == [u32:0, 3, 2]); HasNodeWithType("RES2", "uN[32][3]")))); } +TEST(TypecheckV2Test, NestedLambdaIteratesOverLocalConst) { + EXPECT_THAT( + R"( +fn nested() -> u1[2][3] { + const X = u32:2; + const Y = u32:3; + map(0..Y, | y_idx: u32 | { + map(0..X, | x_idx: u32 | { + if (x_idx + y_idx) % 2 == 0 { + u1:1 + } else { + u1:0 + } + }) + }) +} + +const RES = nested(); +const EX = [ + [u1:1, u1:0], + [u1:0, u1:1], + [u1:1, u1:0], +]; +const_assert!(RES == EX); + +)", + TypecheckSucceeds(AllOf( + HasNodeWithType("RES", "uN[1][2][3]"), + HasNodeWithType("lambda_capture_struct_at_fake.x:7:14-15:5::X", + "uN[32]"), + HasNodeWithType("lambda_capture_struct_at_fake.x:8:18-14:9", + "typeof(lambda_capture_struct_at_fake.x:8:18-14:9 { " + "y_idx: uN[32] }")))); +} + +TEST(TypecheckV2Test, NestedLambdaIteratesOverLocalConstWithExplicitReturn) { + EXPECT_THAT( + R"( +fn nested() -> u1[2][3] { + const X = u32:2; + const Y = u32:3; + map(0..Y, | y_idx: u32 | -> u1[X] { + map(0..X, | x_idx: u32 | { + if (x_idx + y_idx) % 2 == 0 { + u1:1 + } else { + u1:0 + } + }) + }) +} + +const RES = nested(); +const EX = [ + [u1:1, u1:0], + [u1:0, u1:1], + [u1:1, u1:0], +]; +const_assert!(RES == EX); + +)", + TypecheckSucceeds( + AllOf(HasNodeWithType("RES", "uN[1][2][3]"), + HasNodeWithType("lambda_capture_struct_at_fake.x:7:14-15:5::X", + "uN[32]")))); +} + +TEST(TypecheckV2Test, NestedLambdaIteratesOverGlobalConst) { + EXPECT_THAT( + R"( +const X = u32:2; +const Y = u32:3; +type Results = u1[X][Y]; + +fn nested() -> Results { + map(0..Y, | y_idx | { + map(0..X, | x_idx | { + if (x_idx + y_idx) % 2 == 0 { + 1 + } else { + 0 + } + }) + }) +} + +const RES = nested(); +const EX = [ + [u1:1, u1:0], + [u1:0, u1:1], + [u1:1, u1:0], +]; +const_assert!(RES == EX); + +)", + TypecheckSucceeds(HasNodeWithType("RES", "uN[1][2][3]"))); +} + TEST(TypecheckV2Test, LambdaUsesUnrollForOutput) { EXPECT_THAT( R"( const A = u32:1; fn foo() -> u32[5] { let B = u32:2; - const X = unroll_for! (i, a) in u32:0..5 { + let X = unroll_for! (i, a) in u32:0..5 { let C = B + i; let D = A * a; C + D