Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 169 additions & 124 deletions xls/dslx/ir_convert/function_converter.cc

Large diffs are not rendered by default.

29 changes: 23 additions & 6 deletions xls/dslx/ir_convert/function_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,22 @@ class FunctionConverter {
std::variant<BValue, CValue, Channel*, ChannelInterface*, ChannelArray*,
ProcDefInstance*, xls::StateElement*>;

// The `IrValue` for an instance of an impl-based proc. It carries the proc
// definition being instantiated together with per-instance conversion state:
// a channel scope, an IR builder, proc state elements with their initial
// values, and the values bound to the proc's members.
//
// NOTE(review): this span comes from a rendered diff that omits the +/-
// markers, so pre- and post-change lines are interleaved. The removed field
// `std::vector<IrValue> member_values;` (and the removed comment that called
// this "basically a tuple") appear next to its replacement,
// `absl::flat_hash_map<std::string, IrValue> member_values;`. Only the map
// form belongs in the post-change header; C++ does not allow two members
// named `member_values` in one struct.
struct ProcDefInstance {
// The impl-style proc definition this object is an instance of.
const ProcDef* proc_def;
std::vector<IrValue> member_values;
// Identifier for this particular instance (presumably created via a
// ProcIdFactory as in get_conversion_records.cc -- TODO confirm).
ProcId proc_id;
// Type information for this instance; presumably the TypeInfo matching
// `env` when the proc is parametric -- TODO confirm.
TypeInfo* type_info;
// Parametric environment (bindings) for this instance.
ParametricEnv env;

// Channel scope and IR builder owned by this instance (held by unique_ptr).
std::unique_ptr<ChannelScope> channel_scope;
std::unique_ptr<BuilderBase> builder;

// Initial values and IR state elements for the proc's state, keyed by name
// (presumably the state member's name -- TODO confirm).
absl::flat_hash_map<std::string, Value> state_init_values;
absl::flat_hash_map<std::string, StateElement*> state_elements;

// Values bound to the proc's members, keyed by name (presumably the member
// name). We can lower this as a native tuple, if there is ever a need, and if
// channels and channel arrays become native IR values.
absl::flat_hash_map<std::string, IrValue> member_values;
};

// Helper for converting an IR value to its BValue pointer for use in
Expand Down Expand Up @@ -305,6 +315,11 @@ class FunctionConverter {
// Def/DefAlias).
absl::StatusOr<BValue> Use(const AstNode* node) const;

// Visits `source` and propagates its `IrValue` to `dest`. This checks that
// visiting `source` produces an `IrValue`.
absl::Status VisitAndPropagateIrValue(const AstNode* source,
const AstNode* dest);

void SetNodeToIr(const AstNode* node, IrValue value);
void SetNodeToChannelOrArray(const AstNode* node, ChannelOrArray value);
std::optional<IrValue> GetNodeToIr(const AstNode* node) const;
Expand Down Expand Up @@ -536,10 +551,12 @@ class FunctionConverter {
// Handles the impl-style proc `proc_def`; `constructor` is presumably the
// proc's constructor function (e.g. `new`) -- TODO confirm.
absl::Status HandleProcDef(const ProcDef* proc_def,
const Function* constructor);

// Handles the constructor `constructor` of the impl-style proc under the
// parametric `bindings`.
//
// NOTE(review): rendered diff without +/- markers. The lines
// `absl::Status HandleProcDefConstructor(const ProcDef& proc,` and
// `ProcBuilder* builder);` belong to the removed signature. The post-change
// declaration is:
//   HandleProcDefConstructor(const ProcDef* proc_def,
//                            const Function& constructor,
//                            const ParametricEnv& bindings,
//                            TypeInfo* type_info, ProcId proc_id);
absl::Status HandleProcDefConstructor(const ProcDef& proc,
absl::Status HandleProcDefConstructor(const ProcDef* proc_def,
const Function& constructor,
const ParametricEnv& bindings,
ProcBuilder* builder);
TypeInfo* type_info, ProcId proc_id);

// Handles spawning of the proc instance described by `instance`. The
// implementation is in function_converter.cc, whose diff is not rendered
// here, so its exact behavior is not visible -- TODO document.
absl::Status HandleProcDefSpawn(ProcDefInstance* instance);

// Dereferences the type definition to a struct or impl-style proc definition.
absl::StatusOr<StructDefBase*> DerefStructOrProc(TypeDefinition node);
Expand Down
13 changes: 5 additions & 8 deletions xls/dslx/ir_convert/get_conversion_records.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,23 +343,20 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
// Builds a ConversionRecord for the `next` function of the impl-style proc
// `p` and appends it to `records_`.
// - The record has empty parametric bindings and no config record.
// - Its proc ID comes from `proc_id_factory_.CreateProcId(p)`.
// - `is_top` is set from `top_ == next_fn`.
// Fails via XLS_RET_CHECK if `p` has no next function.
//
// NOTE(review): rendered diff without +/- markers. Three groups of lines
// were removed by this change:
// - the `constructor` / `initial_state` lookup,
// - its "Initial state" VLOG,
// - the `/*config_record=*/nullptr, initial_state));` argument line.
// Post-change, no initial state is computed here; the record is built with
// `/*init_value=*/std::nullopt`. As rendered, the MakeConversionRecord call
// has duplicated trailing arguments and is not valid C++.
absl::Status HandleProcDef(const ProcDef* p) override {
VLOG(5) << "HandleProcDef " << p->ToString();

XLS_ASSIGN_OR_RETURN(Function * constructor,
GetTopProcConstructor(p, type_info_));
XLS_ASSIGN_OR_RETURN(InterpValue initial_state,
type_info_->GetConstExpr(constructor->body()));

VLOG(5) << "Initial state: " << initial_state.ToHumanString();

std::optional<Function*> next_fn = GetProcNextFunction(p);
XLS_RET_CHECK(next_fn.has_value());

// Note that we don't need an `init_value` in a `ProcDef` record because the
// counterpart to that plus the `config` output is in the proc instance
// object.
XLS_ASSIGN_OR_RETURN(
ConversionRecord cr,
MakeConversionRecord(*next_fn, p->owner(), type_info_,
/*bindings=*/ParametricEnv(),
/*proc_id=*/proc_id_factory_.CreateProcId(p),
/*is_top=*/top_ == next_fn,
/*config_record=*/nullptr, initial_state));
/*config_record=*/nullptr,
/*init_value=*/std::nullopt));
records_.push_back(std::move(cr));
return absl::OkStatus();
}
Expand Down
36 changes: 36 additions & 0 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3054,6 +3054,42 @@ impl Main {
ExpectIr(converted);
}

// Converts a top-level impl-style proc whose constructor does not return the
// struct-instance expression directly. The instance is bound to `result`,
// re-bound as `foo`, and `foo` is returned. The resulting IR is compared
// against the golden output via `ExpectIr`.
//
// NOTE(review): "Indrect" in the test name is a typo for "Indirect". It is
// kept because `ExpectIr` presumably keys the golden file on the test name;
// renaming the test would also require renaming its golden file.
TEST_F(IrConverterTest, TopProcDefWithIndrectConstructorResult) {
  constexpr std::string_view kProgram = R"(
#![feature(explicit_state_access)]

proc Main {
c_in: chan<u32> in,
c_out: chan<u32> out,
i: u32,
j: u32,
}

impl Main {
fn new(c_in: chan<u32> in, c_out: chan<u32> out) -> Self {
let result = Main { c_in: c_in, c_out: c_out, i: 1, j: 2 };
let foo = result;
foo
}

fn next(self) {
let i_val = read(self.i);
let j_val = read(self.j);
let (tok, v) = recv(join(), self.c_in);
let res = i_val + j_val + v;
let tok = send(tok, self.c_out, res);
write(self.i, res);
}
}
)";

  auto test_import_data = CreateImportDataForTest();
  XLS_ASSERT_OK_AND_ASSIGN(
      std::string ir_text,
      ConvertOneFunctionForTest(kProgram, "Main", test_import_data));
  ExpectIr(ir_text);
}

TEST_F(IrConverterTest, TopProcDefWithNoConstructorFails) {
constexpr std::string_view program = R"(
#![feature(explicit_state_access)]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package test_module

file_number 0 "test_module.x"

top proc __test_module__Main_next<_c_in: bits[32] in, _c_out: bits[32] out>(__i: bits[32], __j: bits[32], init={1, 2}) {
chan_interface _c_in(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface _c_out(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
literal.2: bits[1] = literal(value=1, id=2)
literal.3: bits[1] = literal(value=0, id=3)
not.9: bits[1] = not(literal.2, id=9)
not.10: bits[1] = not(literal.3, id=10)
literal.6: bits[1] = literal(value=0, id=6)
__i__1: bits[32] = state_read(state_element=__i, predicate=literal.2, id=14)
__j__1: bits[32] = state_read(state_element=__j, predicate=literal.2, id=22)
after_all.25: token = after_all(id=25)
__token: token = literal(value=token, id=1)
or.11: bits[1] = or(not.9, not.10, id=11)
not.17: bits[1] = not(literal.2, id=17)
not.18: bits[1] = not(literal.6, id=18)
i_val: bits[32] = identity(__i__1, id=15)
j_val: bits[32] = identity(__j__1, id=23)
receive.26: (token, bits[32]) = receive(after_all.25, predicate=literal.2, channel=_c_in, id=26)
assert.12: token = assert(__token, or.11, message="State element read after read in same activation.", id=12)
or.19: bits[1] = or(not.17, not.18, id=19)
not.33: bits[1] = not(literal.2, id=33)
or.13: bits[1] = or(literal.3, literal.2, id=13)
literal.4: bits[1] = literal(value=0, id=4)
add.30: bits[32] = add(i_val, j_val, id=30)
v: bits[32] = tuple_index(receive.26, index=1, id=29)
assert.20: token = assert(assert.12, or.19, message="State element read after read in same activation.", id=20)
or.34: bits[1] = or(not.33, or.13, id=34)
not.36: bits[1] = not(literal.2, id=36)
not.37: bits[1] = not(literal.4, id=37)
tok: token = tuple_index(receive.26, index=0, id=28)
res: bits[32] = add(add.30, v, id=31)
assert.35: token = assert(assert.20, or.34, message="State element written before read in same activation.", id=35)
or.38: bits[1] = or(not.36, not.37, id=38)
__i: bits[32] = state_read(state_element=__i, id=5)
literal.7: bits[1] = literal(value=0, id=7)
__j: bits[32] = state_read(state_element=__j, id=8)
or.16: bits[1] = or(literal.3, literal.2, id=16)
or.21: bits[1] = or(literal.6, literal.2, id=21)
or.24: bits[1] = or(literal.6, literal.2, id=24)
tuple_index.27: token = tuple_index(receive.26, index=0, id=27)
tok__1: token = send(tok, res, predicate=literal.2, channel=_c_out, id=32)
assert.39: token = assert(assert.35, or.38, message="State element written after write in same activation.", id=39)
or.40: bits[1] = or(literal.4, literal.2, id=40)
next_value.41: () = next_value(param=__i, value=res, predicate=literal.2, id=41)
tuple.42: () = tuple(id=42)
tuple.43: () = tuple(id=43)
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,34 @@ top proc __test_module__Main_next<_c_in: bits[32] in, _c_out: bits[32] out>(__i:
chan_interface _c_out(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
literal.2: bits[1] = literal(value=1, id=2)
literal.3: bits[1] = literal(value=0, id=3)
after_all.15: token = after_all(id=15)
not.7: bits[1] = not(literal.2, id=7)
not.8: bits[1] = not(literal.3, id=8)
__i__1: bits[32] = state_read(state_element=__i, predicate=literal.2, id=12)
receive.16: (token, bits[32]) = receive(after_all.15, predicate=literal.2, channel=_c_in, id=16)
after_all.14: token = after_all(id=14)
not.6: bits[1] = not(literal.2, id=6)
not.7: bits[1] = not(literal.3, id=7)
__i__1: bits[32] = state_read(state_element=__i, predicate=literal.2, id=11)
receive.15: (token, bits[32]) = receive(after_all.14, predicate=literal.2, channel=_c_in, id=15)
__token: token = literal(value=token, id=1)
or.9: bits[1] = or(not.7, not.8, id=9)
not.23: bits[1] = not(literal.2, id=23)
or.11: bits[1] = or(literal.3, literal.2, id=11)
or.8: bits[1] = or(not.6, not.7, id=8)
not.22: bits[1] = not(literal.2, id=22)
or.10: bits[1] = or(literal.3, literal.2, id=10)
literal.4: bits[1] = literal(value=0, id=4)
i_val: bits[32] = identity(__i__1, id=13)
j: bits[32] = tuple_index(receive.16, index=1, id=19)
assert.10: token = assert(__token, or.9, message="State element read after read in same activation.", id=10)
or.24: bits[1] = or(not.23, or.11, id=24)
not.26: bits[1] = not(literal.2, id=26)
not.27: bits[1] = not(literal.4, id=27)
i_val: bits[32] = identity(__i__1, id=12)
j: bits[32] = tuple_index(receive.15, index=1, id=18)
assert.9: token = assert(__token, or.8, message="State element read after read in same activation.", id=9)
or.23: bits[1] = or(not.22, or.10, id=23)
not.25: bits[1] = not(literal.2, id=25)
not.26: bits[1] = not(literal.4, id=26)
tok: token = tuple_index(receive.15, index=0, id=17)
add.19: bits[32] = add(i_val, j, id=19)
assert.24: token = assert(assert.9, or.23, message="State element written before read in same activation.", id=24)
or.27: bits[1] = or(not.25, not.26, id=27)
__i: bits[32] = state_read(state_element=__i, id=5)
tok: token = tuple_index(receive.16, index=0, id=18)
add.20: bits[32] = add(i_val, j, id=20)
assert.25: token = assert(assert.10, or.24, message="State element written before read in same activation.", id=25)
or.28: bits[1] = or(not.26, not.27, id=28)
add.22: bits[32] = add(i_val, j, id=22)
tuple.6: (bits[32]) = tuple(__i, id=6)
or.14: bits[1] = or(literal.3, literal.2, id=14)
tuple_index.17: token = tuple_index(receive.16, index=0, id=17)
tok__1: token = send(tok, add.20, predicate=literal.2, channel=_c_out, id=21)
assert.29: token = assert(assert.25, or.28, message="State element written after write in same activation.", id=29)
or.30: bits[1] = or(literal.4, literal.2, id=30)
next_value.31: () = next_value(param=__i, value=add.22, predicate=literal.2, id=31)
add.21: bits[32] = add(i_val, j, id=21)
or.13: bits[1] = or(literal.3, literal.2, id=13)
tuple_index.16: token = tuple_index(receive.15, index=0, id=16)
tok__1: token = send(tok, add.19, predicate=literal.2, channel=_c_out, id=20)
assert.28: token = assert(assert.24, or.27, message="State element written after write in same activation.", id=28)
or.29: bits[1] = or(literal.4, literal.2, id=29)
next_value.30: () = next_value(param=__i, value=add.21, predicate=literal.2, id=30)
tuple.31: () = tuple(id=31)
tuple.32: () = tuple(id=32)
tuple.33: () = tuple(id=33)
}
25 changes: 3 additions & 22 deletions xls/dslx/type_system_v2/constant_collector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,8 @@ class Visitor : public AstNodeVisitorWithDefault {
return absl::OkStatus();
}

// For a StructInstance node that is creating an impl-style proc, we store a
// tuple of the initial state values as the constexpr value in TypeInfo.
// This is equivalent to the result of the 'init' block in legacy procs.
// For a StructInstance node that is creating an impl-style proc, we collect
// the initial state values.
const ProcType& type = type_.AsProc();
std::vector<InterpValue> state_init_values;
for (const auto& [member_name, initializer] :
Expand Down Expand Up @@ -720,25 +719,7 @@ class Visitor : public AstNodeVisitorWithDefault {
file_table_);
}

state_init_values.push_back(std::move(*value));
}

InterpValue value = InterpValue::MakeTuple(std::move(state_init_values));
ti_->NoteConstExpr(node, value);
VLOG(6) << "Storing value " << value.ToHumanString()
<< " for proc initializer " << node->ToString();

// Propagate the proc "value" through the statement and/or statement block
// that yields it. This way IR conversion can easily say "give me the
// constant value for the body of the proc constructor."
for (AstNode* parent = node->parent();
parent != nullptr && (parent->kind() == AstNodeKind::kStatement ||
parent->kind() == AstNodeKind::kStatementBlock);
parent = parent->parent()) {
VLOG(6) << "Propagating proc initializer value for expr `"
<< node->ToString() << "` to ancestor of kind "
<< AstNodeKindToString(parent->kind());
ti_->NoteConstExpr(parent, value);
ti_->NoteConstExpr(initializer, *value);
}

return absl::OkStatus();
Expand Down
Loading