From ed3b9a155a62232af9f1bc31482dd296997df806 Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Tue, 24 Mar 2026 14:51:15 -0700 Subject: [PATCH] [opt] Use a structural fingerprint to break ties in ReassociationPass Introduces a new lazy analysis, NodeFingerprintAnalysis, which computes a structural hash for each node in a function. This fingerprint is stable across changes to node IDs or names, as long as the underlying expression structure remains the same. The ReassociationPass is updated to use this fingerprint in its sorting comparison for associative elements. By including the fingerprint in the sort order, the pass can more consistently break ties between nodes it can't otherwise distinguish except by name or ID, helping to prevent infinite loops where the pass oscillates between two equivalent IR forms. The comparison now also orders parameters (by position) first and literals last among otherwise-tied elements, and compares node names before falling back to IDs. PiperOrigin-RevId: 888868233 --- xls/passes/BUILD | 37 +++++++ xls/passes/node_fingerprint_analysis.cc | 66 +++++++++++ xls/passes/node_fingerprint_analysis.h | 53 +++++++++ xls/passes/node_fingerprint_analysis_test.cc | 111 +++++++++++++++++++ xls/passes/reassociation_pass.cc | 76 ++++++++++++- xls/passes/reassociation_pass_test.cc | 18 +-- 6 files changed, 346 insertions(+), 15 deletions(-) create mode 100644 xls/passes/node_fingerprint_analysis.cc create mode 100644 xls/passes/node_fingerprint_analysis.h create mode 100644 xls/passes/node_fingerprint_analysis_test.cc diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 4b4baa6e7f..e0961629bb 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -2207,6 +2207,7 @@ xls_pass( ":lazy_dag_cache", ":lazy_node_info", ":lazy_ternary_query_engine", + ":node_fingerprint_analysis", ":optimization_pass", ":pass_base", ":query_engine", @@ -3558,6 +3559,7 @@ cc_test( "//xls/common/fuzzing:fuzztest", "//xls/common/status:matchers", "//xls/common/status:status_macros", + "//xls/dev_tools:remove_identifiers", "//xls/fuzzer/ir_fuzzer:ir_fuzz_domain", "//xls/fuzzer/ir_fuzzer:ir_fuzz_test_library", "//xls/ir", @@ -4806,6 
+4808,24 @@ cc_library( ], ) +cc_library( + name = "node_fingerprint_analysis", + srcs = ["node_fingerprint_analysis.cc"], + hdrs = ["node_fingerprint_analysis.h"], + deps = [ + ":lazy_dag_cache", + ":lazy_node_data", + "//xls/ir", + "//xls/ir:op", + "//xls/ir:type", + "//xls/ir:value", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + ], +) + cc_test( name = "visibility_expr_builder_test", srcs = ["visibility_expr_builder_test.cc"], @@ -4872,3 +4892,20 @@ cc_test( "@googletest//:gtest", ], ) + +cc_test( + name = "node_fingerprint_analysis_test", + srcs = ["node_fingerprint_analysis_test.cc"], + deps = [ + ":node_fingerprint_analysis", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/ir", + "//xls/ir:bits", + "//xls/ir:function_builder", + "//xls/ir:ir_test_base", + "//xls/ir:value", + "@com_google_absl//absl/status:statusor", + "@googletest//:gtest", + ], +) diff --git a/xls/passes/node_fingerprint_analysis.cc b/xls/passes/node_fingerprint_analysis.cc new file mode 100644 index 0000000000..367bcea627 --- /dev/null +++ b/xls/passes/node_fingerprint_analysis.cc @@ -0,0 +1,94 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "xls/passes/node_fingerprint_analysis.h" + +#include <cstdint> +#include <string> +#include <string_view> + +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "xls/ir/function_base.h" +#include "xls/ir/node.h" +#include "xls/ir/nodes.h" +#include "xls/ir/op.h" +#include "xls/ir/type.h" +#include "xls/ir/value.h" + +namespace xls { +namespace { + +// Fingerprints are used to order nodes in passes (e.g. ReassociationPass), so +// they must be reproducible from run to run. absl::Hash is seeded per process +// and interned Type* addresses depend on memory layout, so neither is used +// here; values are mixed with fixed constants instead. + +// Mixes `value` into `seed`: a boost-style hash_combine step followed by the +// splitmix64 finalizer. +uint64_t MixInt(uint64_t seed, uint64_t value) { + uint64_t x = + seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2)); + x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ULL; + x = (x ^ (x >> 27)) * 0x94d049bb133111ebULL; + return x ^ (x >> 31); +} + +// Mixes the length and the FNV-1a hash of `text` into `seed`. +uint64_t MixString(uint64_t seed, std::string_view text) { + uint64_t h = 0xcbf29ce484222325ULL; + for (unsigned char c : text) { + h = (h ^ c) * 0x100000001b3ULL; + } + return MixInt(MixInt(seed, text.size()), h); +} + +} // namespace + +uint64_t NodeFingerprintAnalysis::ComputeInfo( + Node* node, absl::Span<const uint64_t* const> operand_fingerprints) const { + // Every node contributes its op, its type (by textual form, not by pointer), + // and its ordered operand fingerprints. + uint64_t fp = MixInt(0, static_cast<uint64_t>(node->op())); + fp = MixString(fp, node->GetType()->ToString()); + fp = MixInt(fp, operand_fingerprints.size()); + for (const uint64_t* f : operand_fingerprints) { + fp = MixInt(fp, *f); + } + + // Attributes that distinguish otherwise-identical nodes of some kinds. + if (node->Is<Literal>()) { + return MixString(fp, node->As<Literal>()->value().ToString()); + } + if (node->Is<BitSlice>()) { + return MixInt(fp, static_cast<uint64_t>(node->As<BitSlice>()->start())); + } + if (node->Is<TupleIndex>()) { + return MixInt(fp, static_cast<uint64_t>(node->As<TupleIndex>()->index())); + } + if (node->Is<ChannelNode>()) { + return MixString(fp, node->As<ChannelNode>()->channel_name()); + } + if (node->Is<Param>()) { + // Params are identified by position, never by name. + absl::StatusOr<int64_t> param_index = + node->function_base()->GetParamIndex(node->As<Param>()); + CHECK_OK(param_index); + return MixInt(fp, static_cast<uint64_t>(*param_index)); + } + return fp; +} + +} // namespace xls diff --git a/xls/passes/node_fingerprint_analysis.h b/xls/passes/node_fingerprint_analysis.h new file mode 100644 index 0000000000..d846f44c51 --- /dev/null +++ b/xls/passes/node_fingerprint_analysis.h @@ -0,0 +1,53 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_PASSES_NODE_FINGERPRINT_ANALYSIS_H_ +#define XLS_PASSES_NODE_FINGERPRINT_ANALYSIS_H_ + +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "xls/ir/node.h" +#include "xls/passes/lazy_dag_cache.h" +#include "xls/passes/lazy_node_data.h" + +namespace xls { + +// Computes a structural fingerprint for each node from its op, type, operand +// fingerprints, and a few node attributes (e.g. literal value, param index). +// Node names and IDs are not inputs, so they can change without affecting the +// fingerprint. Distinct nodes may collide; use it only to break ties. +class NodeFingerprintAnalysis : public LazyNodeData { + public: + NodeFingerprintAnalysis() + : LazyNodeData(DagCacheInvalidateDirection::kInvalidatesUsers) { + } + + uint64_t GetFingerprint(Node* node) const { return *GetInfo(node); } + + protected: + uint64_t ComputeInfo( + Node* node, + absl::Span operand_fingerprints) const override; + + absl::Status MergeWithGiven(uint64_t& info, + const uint64_t& given) const override { + return absl::InternalError("Cannot merge fingerprints"); + } +}; + +} // namespace xls + +#endif // XLS_PASSES_NODE_FINGERPRINT_ANALYSIS_H_ diff --git a/xls/passes/node_fingerprint_analysis_test.cc b/xls/passes/node_fingerprint_analysis_test.cc new file mode 100644 index 0000000000..ebc7c4db3f --- /dev/null +++ b/xls/passes/node_fingerprint_analysis_test.cc @@ -0,0 +1,111 @@ +// Copyright 2026 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/passes/node_fingerprint_analysis.h" + +#include +#include +#include + +#include "gtest/gtest.h" +#include "absl/status/statusor.h" +#include "xls/common/status/matchers.h" +#include "xls/ir/bits.h" +#include "xls/ir/function_builder.h" +#include "xls/ir/ir_test_base.h" +#include "xls/ir/package.h" +#include "xls/ir/value.h" + +namespace xls { +namespace { + +class NodeFingerprintAnalysisTest : public IrTestBase {}; + +TEST_F(NodeFingerprintAnalysisTest, SimpleFingerprints) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + auto x = fb.Param("x", p->GetBitsType(32)); + auto y = fb.Param("y", p->GetBitsType(32)); + auto add1 = fb.Add(x, y); + auto add2 = fb.Add(x, y); + auto sub = fb.Subtract(x, y); + + XLS_ASSERT_OK_AND_ASSIGN(auto f, fb.Build()); + + NodeFingerprintAnalysis analysis; + XLS_ASSERT_OK(analysis.Attach(f).status()); + + uint64_t fp_x = analysis.GetFingerprint(x.node()); + uint64_t fp_y = analysis.GetFingerprint(y.node()); + uint64_t fp_add1 = analysis.GetFingerprint(add1.node()); + uint64_t fp_add2 = analysis.GetFingerprint(add2.node()); + uint64_t fp_sub = analysis.GetFingerprint(sub.node()); + + EXPECT_NE(fp_x, fp_y); + EXPECT_EQ(fp_add1, fp_add2); + EXPECT_NE(fp_add1, fp_sub); +} + +TEST_F(NodeFingerprintAnalysisTest, IdenticalTreesDifferentNames) { + auto p = CreatePackage(); + + auto build_func = [&](std::string name) -> absl::StatusOr { + FunctionBuilder fb(name, p.get()); + auto x = fb.Param("x", p->GetBitsType(32)); + fb.Add(x, fb.Literal(Value(UBits(1, 32)))); + return fb.Build(); + };
+ + XLS_ASSERT_OK_AND_ASSIGN(auto f1, build_func("f1")); + XLS_ASSERT_OK_AND_ASSIGN(auto f2, build_func("f2")); + + NodeFingerprintAnalysis analysis; + XLS_ASSERT_OK(analysis.Attach(f1).status()); + uint64_t fp1 = analysis.GetFingerprint(f1->return_value()); + + XLS_ASSERT_OK(analysis.Attach(f2).status()); + uint64_t fp2 = analysis.GetFingerprint(f2->return_value()); + + EXPECT_EQ(fp1, fp2); +} + +TEST_F(NodeFingerprintAnalysisTest, ParametersTrackedByPosition) { + auto p = CreatePackage(); + + FunctionBuilder fb1("f1", p.get()); + auto x1 = fb1.Param("x", p->GetBitsType(32)); + auto y1 = fb1.Param("y", p->GetBitsType(32)); + fb1.Add(x1, y1); + XLS_ASSERT_OK_AND_ASSIGN(auto f1, fb1.Build()); + + FunctionBuilder fb2("f2", p.get()); + auto y2 = fb2.Param("y", p->GetBitsType(32)); + auto x2 = fb2.Param("x", p->GetBitsType(32)); + fb2.Add(y2, x2); + XLS_ASSERT_OK_AND_ASSIGN(auto f2, fb2.Build()); + + NodeFingerprintAnalysis analysis; + XLS_ASSERT_OK(analysis.Attach(f1).status()); + uint64_t fp1 = analysis.GetFingerprint(f1->return_value()); + + XLS_ASSERT_OK(analysis.Attach(f2).status()); + uint64_t fp2 = analysis.GetFingerprint(f2->return_value()); + + // These should be equal because (Param0 + Param1) is structurally the same in + // both functions.
+ EXPECT_EQ(fp1, fp2); +} + +} // namespace +} // namespace xls diff --git a/xls/passes/reassociation_pass.cc b/xls/passes/reassociation_pass.cc index a5b9ceebea..a2f78e774f 100644 --- a/xls/passes/reassociation_pass.cc +++ b/xls/passes/reassociation_pass.cc @@ -60,6 +60,7 @@ #include "xls/passes/lazy_dag_cache.h" #include "xls/passes/lazy_node_info.h" #include "xls/passes/lazy_ternary_query_engine.h" +#include "xls/passes/node_fingerprint_analysis.h" #include "xls/passes/optimization_pass.h" #include "xls/passes/pass_base.h" #include "xls/passes/query_engine.h" @@ -1105,22 +1106,32 @@ class Reassociation { XLS_RET_CHECK(elements.op()) << "not associative operation" << elements.node(); VLOG(2) << "Reassociating operation " << elements.node() - << " with elements: [" << elements.ElementsToString() << "]"; + << " with elements: [" << elements.ElementsToString() << "] as " + << (is_signed ? "signed" : "unsigned"); std::vector variable_elements(elements.variables().begin(), elements.variables().end()); // Make sure that variable elements are sorted consistently to ensure that // CSE will be able to merge them. Sort by bit-count, non-reassociativity, - // negatedness then id. + // kind (params, ops, literals), negatedness, fingerprint, name, then id. + XLS_ASSIGN_OR_RETURN(NodeFingerprintAnalysis * fingerprint, + context_.SharedNodeData(fb_)); auto is_basic_candidate = [&](Node* n) -> bool { auto elem = cache_.GetInfo(n).Get({}); return elem && (!elem->signed_values.is_leaf() || !elem->unsigned_values.is_leaf()); }; auto elements_cmp = [&](const NodeData& a, const NodeData& b) { + if (a.node == b.node) { + // If the nodes are the same, sort by negatedness; everything else we + // might check will be equal. + return (a.needs_negate != b.needs_negate) ?
a.needs_negate : false; + } + auto bit_count_comp = a.node->BitCountOrDie() <=> b.node->BitCountOrDie(); if (bit_count_comp != std::strong_ordering::equal) { return bit_count_comp == std::strong_ordering::less; } + // Try to push nodes which we can feasibly reassociate again to the right // since we build the tree left biased further right can be at lower // depths which is good if we end up being able to reassociate again. @@ -1129,11 +1140,64 @@ class Reassociation { if (lhs_is_reassoc_candidate != rhs_is_reassoc_candidate) { return !lhs_is_reassoc_candidate; } - auto id_cmp = a.node->id() <=> b.node->id(); - if (id_cmp != std::strong_ordering::equal) { - return id_cmp == std::strong_ordering::less; + + // Try to push things into the order: params, operations, literals. + bool a_is_param = a.node->Is(); + bool b_is_param = b.node->Is(); + if (a_is_param && b_is_param) { + absl::StatusOr a_param_index = + a.node->function_base()->GetParamIndex(a.node->As()); + absl::StatusOr b_param_index = + b.node->function_base()->GetParamIndex(b.node->As()); + CHECK_OK(a_param_index); + CHECK_OK(b_param_index); + auto param_index_cmp = *a_param_index <=> *b_param_index; + if (param_index_cmp != std::strong_ordering::equal) { + return param_index_cmp == std::strong_ordering::less; + } + } else if (a_is_param) { + return true; + } else if (b_is_param) { + return false; + } + + bool a_is_literal = a.node->Is(); + bool b_is_literal = b.node->Is(); + if (a_is_literal != b_is_literal) { + return !a_is_literal; + } + + // All else being equal, sort negated elements first (to the left). + if (a.needs_negate != b.needs_negate) { + return a.needs_negate; + } + + // We're out of meaningful ways to sort - so we have to resort to + // arbitrary ordering to keep our results consistent/deterministic. + + // First, try sorting by structural fingerprint; this will be stable even + // as name and ID change.
+ auto fingerprint_cmp = fingerprint->GetFingerprint(a.node) <=> + fingerprint->GetFingerprint(b.node); + if (fingerprint_cmp != std::strong_ordering::equal) { + return fingerprint_cmp == std::strong_ordering::less; + } + + // Try sorting by name (in descending order), since names change less + // often than IDs; if only one node has a name, push it to the right. + if (a.node->HasAssignedName() && b.node->HasAssignedName()) { + auto name_cmp = a.node->GetNameView() <=> b.node->GetNameView(); + if (name_cmp != std::strong_ordering::equal) { + return name_cmp == std::strong_ordering::greater; + } + } else if (a.node->HasAssignedName()) { + return false; + } else if (b.node->HasAssignedName()) { + return true; } - return (a.needs_negate != b.needs_negate) ? a.needs_negate : false; + + // If we have no other options, sort by ID. + return (a.node->id() <=> b.node->id()) == std::strong_ordering::less; }; absl::c_sort(variable_elements, elements_cmp); std::string associative_sum_name = diff --git a/xls/passes/reassociation_pass_test.cc b/xls/passes/reassociation_pass_test.cc index 3ed88cda2a..be07b978ed 100644 --- a/xls/passes/reassociation_pass_test.cc +++ b/xls/passes/reassociation_pass_test.cc @@ -30,6 +30,7 @@ #include "absl/types/span.h" #include "xls/common/status/matchers.h" #include "xls/common/status/status_macros.h" +#include "xls/dev_tools/remove_identifiers.h" #include "xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h" #include "xls/fuzzer/ir_fuzzer/ir_fuzz_test_library.h" #include "xls/ir/bits.h" @@ -82,6 +83,8 @@ class ReassociationPassTest : public IrTestBase { bool run_result = false; pass.Add>(&run_result); pass.Add(); + pass.Add(); + pass.Add(); XLS_ASSIGN_OR_RETURN( bool compound_result, pass.Run(p, OptimizationPassOptions(), &results, context)); @@ -219,7 +222,7 @@ TEST_F(ReassociationPassTest, ZeroPlusConstant) { XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); ScopedVerifyEquivalence sve(f); ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); - EXPECT_THAT(f->return_value(),
m::Sub(m::Literal(), m::Literal())); + EXPECT_THAT(f->return_value(), m::Literal(8)); } TEST_F(ReassociationPassTest, ChainOfConstants) { @@ -749,9 +752,7 @@ TEST_F(ReassociationPassTest, ReassociateMultipleLiterals) { ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); EXPECT_THAT(f->return_value(), m::Add(m::Add(m::Param("a"), m::Param("b")), - m::Add(m::Param("c"), - m::Add(m::Add(m::Literal(42), m::Literal(123)), - m::Literal(10))))); + m::Add(m::Param("c"), m::Literal(175)))); } TEST_F(ReassociationPassTest, SingleLiteralNoReassociate) { @@ -909,9 +910,9 @@ TEST_F(ReassociationPassTest, BalanceEarlyUseIsNotDuplicated) { EXPECT_THAT( f->return_value(), m::Tuple(lhs.node(), - m::Add(m::Add(m::Add(lhs.node(), m::Param("a")), - m::Add(m::Param("rhs0"), m::Param("rhs1"))), - m::Add(m::Param("rhs2"), m::Param("rhs3"))))); + m::Add(m::Add(m::Add(m::Param("a"), m::Param("rhs0")), + m::Add(m::Param("rhs1"), m::Param("rhs2"))), + m::Add(m::Param("rhs3"), lhs.node())))); } TEST_F(ReassociationPassTest, DoubleUseBalanceDoesntChange) { @@ -1026,8 +1027,7 @@ TEST_F(ReassociationPassTest, SubUnderflowZeroExtend5) { XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.Build()); ScopedVerifyEquivalence sve(f); ASSERT_THAT(Run(p.get()), IsOkAndHolds(true)); - EXPECT_THAT(f->return_value(), - m::Sub(m::Param("param"), m::ZeroExt(lit_sub.node()))); + EXPECT_THAT(f->return_value(), m::Sub(m::Param("param"), m::Literal(16))); } TEST_F(ReassociationPassTest, ConcatMultipleValues) {