Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions xls/ir/ir_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,11 @@ bool StateReadMatcher::MatchAndExplain(
*listener << " has incorrect label";
return false;
}
if (predicate_.has_value() &&
!predicate_->MatchAndExplain(node->As<xls::StateRead>()->predicate(),
listener)) {
return false;
}
return true;
}

Expand All @@ -646,6 +651,12 @@ void StateReadMatcher::DescribeTo(::std::ostream* os) const {
label_->DescribeTo(&ss);
additional_fields.push_back(ss.str());
}
if (predicate_.has_value()) {
std::stringstream ss;
ss << "predicate=";
predicate_->DescribeTo(&ss);
additional_fields.push_back(ss.str());
}
DescribeToHelper(os, additional_fields);
}

Expand Down
17 changes: 15 additions & 2 deletions xls/ir/ir_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -1269,16 +1269,20 @@ inline ::testing::Matcher<const ::xls::Node*> OutputPort(
// EXPECT_THAT(x, m::StateRead());
// EXPECT_THAT(x, m::StateRead("x"));
// EXPECT_THAT(x, m::StateRead(HasSubstr("substr")));
// EXPECT_THAT(x, m::StateRead("x", /*label=*/_,
//                             /*predicate=*/Optional(m::Param("pred"))));
//
class StateReadMatcher : public NodeMatcher {
public:
explicit StateReadMatcher(
std::optional<::testing::Matcher<const std::string>> state_element_name,
std::optional<::testing::Matcher<const std::optional<std::string>&>>
label = std::nullopt)
label = std::nullopt,
std::optional<::testing::Matcher<std::optional<Node*>>> predicate =
std::nullopt)
: NodeMatcher(Op::kStateRead, /*operands=*/{}),
state_element_name_(std::move(state_element_name)),
label_(std::move(label)) {}
label_(std::move(label)),
predicate_(std::move(predicate)) {}

bool MatchAndExplain(const Node* node,
::testing::MatchResultListener* listener) const override;
Expand All @@ -1287,6 +1291,7 @@ class StateReadMatcher : public NodeMatcher {
private:
std::optional<::testing::Matcher<const std::string>> state_element_name_;
std::optional<::testing::Matcher<const std::optional<std::string>&>> label_;
std::optional<::testing::Matcher<std::optional<Node*>>> predicate_;
};

template <typename T>
Expand Down Expand Up @@ -1324,6 +1329,14 @@ inline ::testing::Matcher<const ::xls::Node*> StateRead() {
return ::xls::op_matchers::StateReadMatcher(std::nullopt, std::nullopt);
}

// Matches a StateRead node whose state element name, label, and read predicate
// are accepted by `name`, `label`, and `predicate` respectively.
//
// `predicate` is checked against the read's `std::optional<Node*>` predicate,
// so a node matcher has to be wrapped, e.g. `Optional(m::Param("pred"))`; pass
// `_` for any argument that should not be constrained.
inline ::testing::Matcher<const ::xls::Node*> StateRead(
    ::testing::Matcher<const std::string> name,
    ::testing::Matcher<const std::optional<std::string>&> label,
    ::testing::Matcher<std::optional<Node*>> predicate) {
  ::xls::op_matchers::StateReadMatcher matcher(
      /*state_element_name=*/std::move(name), /*label=*/std::move(label),
      /*predicate=*/std::move(predicate));
  return matcher;
}

// Next matcher. Supported forms:
//
// EXPECT_THAT(x, m::Next());
Expand Down
84 changes: 57 additions & 27 deletions xls/ir/proc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -930,26 +930,46 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
Proc::StateElementTransformer& transform) {
StateElement* old_state_element = old_state_read->state_element();
std::string orig_name(old_state_element->name());
std::string orig_read_name(old_state_read->GetNameView());
XLS_ASSIGN_OR_RETURN(std::optional<Node*> read_predicate,
transform.TransformReadPredicate(this, old_state_read));

absl::Span<StateRead* const> all_old_reads =
GetStateReadsByStateElement(old_state_element);
absl::flat_hash_map<StateRead*, std::string> orig_read_names;
for (StateRead* old_read : all_old_reads) {
orig_read_names[old_read] = std::string(old_read->GetNameView());
}

XLS_ASSIGN_OR_RETURN(
StateRead * new_state_read,
AppendStateElement(absl::StrFormat("TEMP_NAME__%s__", orig_name),
init_value, read_predicate,
/*next_state=*/std::nullopt));
new_state_read->SetLoc(old_state_read->loc());
if (old_state_read->state_element()->non_synthesizable()) {
new_state_read->state_element()->SetNonSynthesizable();
StateElement * new_state_element,
AppendUnreadStateElement(absl::StrFormat("TEMP_NAME__%s__", orig_name),
init_value, old_state_read->loc()));
if (old_state_element->non_synthesizable()) {
new_state_element->SetNonSynthesizable();
}
StateElement* new_state_element = new_state_read->state_element();
std::string temp_name = new_state_element->name();

XLS_ASSIGN_OR_RETURN(
Node * new_state_value,
transform.TransformStateRead(this, new_state_read, old_state_read));
std::vector<std::pair<Node*, Node*>> to_replace{
{old_state_read, new_state_value}};
absl::flat_hash_map<StateRead*, StateRead*> old_to_new_read;
StateRead* return_state_read = nullptr;
std::vector<std::pair<Node*, Node*>> to_replace;

for (StateRead* old_read : all_old_reads) {
XLS_ASSIGN_OR_RETURN(std::optional<Node*> read_predicate,
transform.TransformReadPredicate(this, old_read));
XLS_ASSIGN_OR_RETURN(StateRead * new_read,
MakeNodeWithName<StateRead>(
old_read->loc(), new_state_element, read_predicate,
old_read->label(), temp_name));
state_reads_[new_state_element].push_back(new_read);
old_to_new_read[old_read] = new_read;

if (old_read == old_state_read) {
return_state_read = new_read;
}

XLS_ASSIGN_OR_RETURN(Node * new_state_value, transform.TransformStateRead(
this, new_read, old_read));
to_replace.push_back({old_read, new_state_value});
}

struct NextTransformation {
Next* old_next;
Node* new_value;
Expand All @@ -959,14 +979,18 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
for (Next* nxt : next_values(old_state_element)) {
NextTransformation& new_next = transforms.emplace_back();
new_next.old_next = nxt;
XLS_ASSIGN_OR_RETURN(new_next.new_value, transform.TransformNextValue(
this, new_state_read, nxt));
XLS_RET_CHECK(new_next.new_value->GetType() == new_state_read->GetType())
StateRead* corresponding_new_read =
old_to_new_read.at(nxt->state_read()->As<StateRead>());
XLS_ASSIGN_OR_RETURN(
new_next.new_value,
transform.TransformNextValue(this, corresponding_new_read, nxt));
XLS_RET_CHECK(new_next.new_value->GetType() ==
corresponding_new_read->GetType())
<< "New value is not compatible type. Expected: "
<< new_state_read->GetType() << " got " << new_next.new_value;
<< corresponding_new_read->GetType() << " got " << new_next.new_value;
XLS_ASSIGN_OR_RETURN(
new_next.new_predicate,
transform.TransformNextPredicate(this, new_state_read, nxt));
transform.TransformNextPredicate(this, corresponding_new_read, nxt));
}

// We've transformed all the graph elements. Start replacing them.
Expand All @@ -977,24 +1001,30 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
auto orig_storage = state_elements_.extract(orig_name);
orig_storage.key() = to_remove_name;
old_state_element->SetName(to_remove_name);
old_state_read->SetName(to_remove_name);
for (StateRead* old_read : all_old_reads) {
old_read->SetName(to_remove_name);
}
CHECK(state_elements_.insert(std::move(orig_storage)).inserted);

// Take over the old state element & read names.
auto new_storage = state_elements_.extract(temp_name);
new_storage.key() = orig_name;
new_state_element->SetName(orig_name);
new_state_read->SetNameDirectly(orig_read_name);
for (auto& [old_read, new_read] : old_to_new_read) {
new_read->SetNameDirectly(orig_read_names.at(old_read));
}
CHECK(state_elements_.insert(std::move(new_storage)).inserted);

// Identity-ify the old next nodes and create new ones.
for (const NextTransformation& nt : transforms) {
// Make the next
StateRead* corresponding_new_read =
old_to_new_read.at(nt.old_next->state_read()->As<StateRead>());
XLS_ASSIGN_OR_RETURN(
Next * nxt,
MakeNodeWithName<Next>(nt.old_next->loc(), new_state_read, nt.new_value,
nt.new_predicate, nt.old_next->label(),
nt.old_next->GetName()));
MakeNodeWithName<Next>(nt.old_next->loc(), corresponding_new_read,
nt.new_value, nt.new_predicate,
nt.old_next->label(), nt.old_next->GetName()));
to_replace.push_back({nt.old_next, nxt});
// Identity-ify the old next.
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
Expand All @@ -1011,7 +1041,7 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
},
/*replace_implicit_uses=*/false));
}
return new_state_read;
return return_state_read;
}

absl::Status Proc::InternalRebuildSideTables() {
Expand Down
69 changes: 69 additions & 0 deletions xls/ir/proc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,75 @@ TEST_F(ProcTest, TransformStateElement) {
EXPECT_THAT(user.node(), m::Tuple(m::Neg(new_st)));
}

// Transforming a state element that has more than one StateRead must:
// - create one new read per old read;
// - turn every old next into an identity update;
// - make each new Next target the new read that corresponds to its old read.
TEST_F(ProcTest, TransformStateElementMultipleReads) {
  auto p = CreatePackage();
  TokenlessProcBuilder pb(TestName(), "tkn", p.get());
  auto st = pb.StateElement("st", UBits(0b1010, 4));
  auto cond = pb.StateElement("cond", UBits(0, 1));
  auto add_st = pb.Next(st, pb.Add(st, pb.Literal(UBits(1, 4))), cond);
  pb.Next(cond, pb.Not(cond));

  XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());

  // Manually add a second read of 'st' with its own (subtracting) next value.
  StateElement* st_elem = proc->GetStateElement(0);
  XLS_ASSERT_OK_AND_ASSIGN(
      StateRead * st_read2,
      proc->MakeNodeWithName<StateRead>(SourceInfo(), st_elem,
                                        /*predicate=*/std::nullopt,
                                        /*label=*/std::nullopt, "st_read2"));
  XLS_ASSERT_OK_AND_ASSIGN(
      Node * lit_sub,
      proc->MakeNode<Literal>(SourceInfo(), Value(UBits(2, 4))));
  XLS_ASSERT_OK_AND_ASSIGN(
      Node * sub_st2,
      proc->MakeNode<BinOp>(SourceInfo(), st_read2, lit_sub, Op::kSub));
  XLS_ASSERT_OK_AND_ASSIGN(
      Next * next_st2,
      proc->MakeNodeWithName<Next>(SourceInfo(), st_read2, sub_st2,
                                   /*predicate=*/std::nullopt,
                                   /*label=*/std::nullopt, "next_st2"));
  XLS_ASSERT_OK(proc->RebuildSideTables());

  // The side tables must know about both reads of 'st'.
  EXPECT_EQ(proc->GetStateReadsByStateElement(st_elem).size(), 2);

  // Transformer that negates the state value on every read and negates every
  // next value.
  struct TestTransformer : public Proc::StateElementTransformer {
    absl::StatusOr<Node*> TransformStateRead(
        Proc* proc, StateRead* new_state_read,
        StateRead* old_state_read) override {
      return proc->MakeNode<UnOp>(new_state_read->loc(), new_state_read,
                                  Op::kNeg);
    }
    absl::StatusOr<Node*> TransformNextValue(Proc* proc,
                                             StateRead* new_state_read,
                                             Next* old_next) override {
      return proc->MakeNode<UnOp>(old_next->value()->loc(), old_next->value(),
                                  Op::kNeg);
    }
  };
  TestTransformer tt;
  XLS_ASSERT_OK_AND_ASSIGN(
      StateRead * new_st,
      proc->TransformStateElement(st.node()->As<StateRead>(),
                                  Value(UBits(0b0101, 4)), tt));

  // The returned read belongs to the new 'st'. The old next for the first
  // read has been turned into an identity update.
  EXPECT_THAT(new_st, m::StateRead("st"));
  EXPECT_THAT(add_st.node(), m::Next(st.node(), st.node(), cond.node()));

  // The second read was replaced by a distinct new read of the same new state
  // element. Its old next has also been turned into an identity update.
  Node* new_st_read2 = FindNode("st_read2", proc);
  ASSERT_NE(new_st_read2, nullptr);
  EXPECT_NE(new_st_read2, new_st);
  EXPECT_EQ(new_st_read2->As<StateRead>()->state_element(),
            new_st->state_element());
  EXPECT_THAT(next_st2, m::Next(st_read2, st_read2));

  // Each new read must be updated by exactly one new Next node. That node's
  // value must be the transformed value of the *corresponding* old next:
  // neg(add) for the first read, neg(sub) for the second.
  int new_st_next_count = 0;
  int new_st_read2_next_count = 0;
  Next* new_st_next = nullptr;
  Next* new_st_read2_next = nullptr;
  for (Node* node : proc->nodes()) {
    if (!node->Is<Next>()) {
      continue;
    }
    Next* nxt = node->As<Next>();
    if (nxt->state_read() == new_st) {
      ++new_st_next_count;
      new_st_next = nxt;
    } else if (nxt->state_read() == new_st_read2) {
      ++new_st_read2_next_count;
      new_st_read2_next = nxt;
    }
  }
  EXPECT_EQ(new_st_next_count, 1);
  EXPECT_EQ(new_st_read2_next_count, 1);
  ASSERT_NE(new_st_next, nullptr);
  ASSERT_NE(new_st_read2_next, nullptr);
  EXPECT_THAT(new_st_next->value(), m::Neg(m::Add()));
  EXPECT_THAT(new_st_read2_next->value(), m::Neg(m::Sub()));
}

class ScheduledProcTest : public IrTestBase {
protected:
absl::StatusOr<ScheduledProc*> CreateScheduledProc(Package* p) {
Expand Down
3 changes: 3 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -2507,6 +2507,7 @@ xls_pass(
"//xls/ir:type",
"//xls/ir:value",
"//xls/ir:value_utils",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:btree",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
Expand Down Expand Up @@ -2565,6 +2566,7 @@ xls_pass(
"@com_google_absl//absl/log",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)

Expand All @@ -2576,6 +2578,7 @@ xls_pass(
deps = [
":optimization_pass",
":pass_base",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:bits",
Expand Down
35 changes: 25 additions & 10 deletions xls/passes/array_untuple_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,32 @@ absl::StatusOr<absl::flat_hash_set<Node*>> FindExternalGroups(
// Don't mess with params that are only used in identity updates. Would
// infinite loop otherwise since we don't remove these very often.
for (StateElement* state_element : f->AsProcOrDie()->StateElements()) {
StateRead* state_read =
f->AsProcOrDie()->GetStateReadByStateElement(state_element);
if (absl::c_all_of(state_read->users(), [&](Node* n) -> bool {
if (n->Is<Next>()) {
absl::Span<StateRead* const> state_reads =
f->AsProcOrDie()->GetStateReadsByStateElement(state_element);
bool all_reads_identity = true;
for (StateRead* state_read : state_reads) {
if (!absl::c_all_of(state_read->users(), [&](Node* n) -> bool {
if (!n->Is<Next>()) {
return false;
}
Next* nxt = n->As<Next>();
return nxt->state_read() == nxt->value() &&
nxt->state_read() == state_read;
}
return false;
})) {
excluded.insert(groups.Find(state_read));
if (nxt->state_element() != state_element) {
return false;
}
if (!nxt->value()->Is<StateRead>()) {
return false;
}
return nxt->value()->As<StateRead>()->state_element() ==
state_element;
})) {
all_reads_identity = false;
break;
}
}
if (all_reads_identity) {
for (StateRead* state_read : state_reads) {
excluded.insert(groups.Find(state_read));
}
}
}
}
Expand Down
Loading
Loading