diff --git a/xls/build_rules/tests/fuzz_test_example.x b/xls/build_rules/tests/fuzz_test_example.x
index 9aa4522439..d0ebe8b61d 100644
--- a/xls/build_rules/tests/fuzz_test_example.x
+++ b/xls/build_rules/tests/fuzz_test_example.x
@@ -12,4 +12,5 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+// `domains` supplies one fuzzing domain per parameter, in order: `x` and `y`
+// are both drawn from `u32:0..100`.
+#[fuzz_test(domains=`u32:0..100, u32:0..100`)]
 fn my_fuzz_property(x: u32, y: u32) -> bool { x + y == y + x }
diff --git a/xls/dev_tools/BUILD b/xls/dev_tools/BUILD
index ef5da7f5a4..e7ee4f77ee 100644
--- a/xls/dev_tools/BUILD
+++ b/xls/dev_tools/BUILD
@@ -119,6 +119,7 @@ cc_library(
srcs = ["extract_interface.cc"],
hdrs = ["extract_interface.h"],
deps = [
+ "//xls/common:attribute_data",
"//xls/ir",
"//xls/ir:channel",
"//xls/ir:register",
@@ -127,6 +128,7 @@ cc_library(
"//xls/ir:xls_ir_interface_cc_proto",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
+ "@com_google_protobuf//:protobuf",
],
)
@@ -154,6 +156,7 @@ cc_test(
srcs = ["extract_interface_test.cc"],
deps = [
":extract_interface",
+ "//xls/common:attribute_data",
"//xls/common:proto_test_utils",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
diff --git a/xls/dev_tools/extract_interface.cc b/xls/dev_tools/extract_interface.cc
index 6e6597989a..953c8ad792 100644
--- a/xls/dev_tools/extract_interface.cc
+++ b/xls/dev_tools/extract_interface.cc
@@ -14,8 +14,13 @@
#include "xls/dev_tools/extract_interface.h"
+#include
+#include
+
#include "absl/log/check.h"
#include "absl/status/statusor.h"
+#include "google/protobuf/text_format.h"
+#include "xls/common/attribute_data.h"
#include "xls/ir/block.h"
#include "xls/ir/channel.h"
#include "xls/ir/function.h"
@@ -61,6 +66,33 @@ PackageInterfaceProto::Function ExtractFunctionInterface(Function* func) {
AddNamed(proto.add_parameters(), param);
}
*proto.mutable_result_type() = func->GetType()->return_type()->ToProto();
+
+  // A kFuzzTest attribute may carry a backticked `domains` argument holding a
+  // single-line text-format PackageInterfaceProto::Function; only its
+  // `parameter_domains` field is copied into `proto`. Only the first kFuzzTest
+  // attribute is consulted. This function cannot return a status, so
+  // malformed attributes CHECK-fail.
+  for (const AttributeData& attr : func->attributes()) {
+    if (attr.kind() != AttributeKind::kFuzzTest) {
+      continue;
+    }
+    for (const AttributeData::Argument& arg : attr.args()) {
+      const auto* skv =
+          std::get_if<AttributeData::StringKeyValueArgument>(&arg);
+      CHECK(skv != nullptr)
+          << "kFuzzTest argument must be StringKeyValueArgument";
+      CHECK_EQ(skv->first, "domains")
+          << "kFuzzTest only supports 'domains' argument";
+      CHECK(skv->is_backticked)
+          << "kFuzzTest domains argument must be backticked";
+
+      // Parse into a scratch message; any fields other than
+      // `parameter_domains` in the text are ignored.
+      PackageInterfaceProto::Function parsed;
+      CHECK(google::protobuf::TextFormat::ParseFromString(skv->second, &parsed))
+          << "Failed to parse fuzz test domains proto from attribute";
+      proto.mutable_parameter_domains()->CopyFrom(parsed.parameter_domains());
+    }
+    break;
+  }
+
return proto;
}
diff --git a/xls/dev_tools/extract_interface_test.cc b/xls/dev_tools/extract_interface_test.cc
index 77fd46e835..9808411bb0 100644
--- a/xls/dev_tools/extract_interface_test.cc
+++ b/xls/dev_tools/extract_interface_test.cc
@@ -15,9 +15,12 @@
#include "xls/dev_tools/extract_interface.h"
#include
+#include
+#include
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "xls/common/attribute_data.h"
#include "xls/common/proto_test_utils.h"
#include "xls/common/status/matchers.h"
#include "xls/ir/bits.h"
@@ -164,5 +167,102 @@ TEST_F(ExtractInterfaceTest, BasicBlock) {
)pb"));
}
+TEST_F(ExtractInterfaceTest, FuzzTestFunction) {
+  constexpr std::string_view kIr = R"(
+package test
+
+#[fuzz_test(domains = `parameter_domains { range { min { bits { bit_count: 32 data: "\000" } } max { bits { bit_count: 32 data: "\012" } } } } parameter_domains { arbitrary: true }`)]
+fn f(x: bits[32], y: bits[32]) -> bits[32] {
+  ret x: bits[32] = param(name=x)
+}
+)";
+  XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(kIr));
+
+  PackageInterfaceProto proto = ExtractPackageInterface(p.get());
+
+  ASSERT_EQ(proto.functions().size(), 1);
+  const auto& func_proto = proto.functions(0);
+
+  // Domains appear in the order they are listed in the attribute.
+  ASSERT_EQ(func_proto.parameter_domains().size(), 2);
+  EXPECT_TRUE(func_proto.parameter_domains(1).arbitrary());
+
+  // Check that the range bounds survive the text-proto round trip, not merely
+  // that a `range` field is present.
+  ASSERT_TRUE(func_proto.parameter_domains(0).has_range());
+  const auto& range = func_proto.parameter_domains(0).range();
+  EXPECT_EQ(range.min().bits().bit_count(), 32);
+  EXPECT_EQ(range.min().bits().data(), std::string(1, '\0'));
+  EXPECT_EQ(range.max().bits().bit_count(), 32);
+  EXPECT_EQ(range.max().bits().data(), std::string(1, '\n'));
+}
+
+TEST_F(ExtractInterfaceTest, FuzzTestFunctionNoDomains) {
+  // A bare `#[fuzz_test]` attribute carries no domains, so the extracted
+  // function interface must not list any.
+  constexpr std::string_view kPackageIr = R"(
+package test
+
+#[fuzz_test]
+fn f(x: bits[32]) -> bits[32] {
+  ret x: bits[32] = param(name=x)
+}
+)";
+  XLS_ASSERT_OK_AND_ASSIGN(auto package, ParsePackage(kPackageIr));
+  const PackageInterfaceProto iface = ExtractPackageInterface(package.get());
+  ASSERT_THAT(iface.functions(), testing::SizeIs(1));
+  EXPECT_TRUE(iface.functions(0).parameter_domains().empty());
+}
+
+TEST_F(ExtractInterfaceTest, FuzzTestFunctionInvalidArgKeyManual) {
+  VerifiedPackage p("test_package");
+  Function* f = nullptr;
+  {
+    FunctionBuilder fb("f", &p);
+    fb.Param("x", p.GetBitsType(32));
+    XLS_ASSERT_OK_AND_ASSIGN(f, fb.Build());
+  }
+
+  // Attach a kFuzzTest attribute whose only key is not `domains`; interface
+  // extraction treats this as a fatal error.
+  std::vector<AttributeData::Argument> args;
+  args.push_back(AttributeData::StringKeyValueArgument{
+      .first = "invalid", .second = "value", .is_backticked = true});
+  f->AddAttribute(AttributeData(AttributeKind::kFuzzTest, std::move(args)));
+
+  EXPECT_DEATH(ExtractPackageInterface(&p),
+               "kFuzzTest only supports 'domains' argument");
+}
+
+TEST_F(ExtractInterfaceTest, PackageWithMixOfFuzzAndNonFuzzFunctions) {
+  constexpr std::string_view kIr = R"(
+package test
+
+#[fuzz_test(domains = `parameter_domains { arbitrary: true }`)]
+fn fuzz_me(x: bits[32]) -> bits[32] {
+  ret x: bits[32] = param(name=x)
+}
+
+fn dont_fuzz_me(x: bits[32]) -> bits[32] {
+  ret x: bits[32] = param(name=x)
+}
+)";
+  XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(kIr));
+
+  PackageInterfaceProto proto = ExtractPackageInterface(p.get());
+  ASSERT_EQ(proto.functions().size(), 2);
+
+  // Returns the interface of the function named `name`, or nullptr.
+  auto find_function =
+      [&proto](std::string_view name) -> const PackageInterfaceProto::Function* {
+    for (const PackageInterfaceProto::Function& f : proto.functions()) {
+      if (f.base().name() == name) {
+        return &f;
+      }
+    }
+    return nullptr;
+  };
+  const PackageInterfaceProto::Function* fuzz_proto = find_function("fuzz_me");
+  const PackageInterfaceProto::Function* non_fuzz_proto =
+      find_function("dont_fuzz_me");
+  ASSERT_NE(fuzz_proto, nullptr);
+  ASSERT_NE(non_fuzz_proto, nullptr);
+
+  // ASSERT (not EXPECT) the count so a regression fails cleanly instead of
+  // indexing past the end of `parameter_domains`.
+  ASSERT_EQ(fuzz_proto->parameter_domains().size(), 1);
+  EXPECT_TRUE(fuzz_proto->parameter_domains(0).arbitrary());
+
+  EXPECT_TRUE(non_fuzz_proto->parameter_domains().empty());
+}
+
} // namespace
} // namespace xls
diff --git a/xls/dslx/ir_convert/BUILD b/xls/dslx/ir_convert/BUILD
index 3b154dca43..5377f864d5 100644
--- a/xls/dslx/ir_convert/BUILD
+++ b/xls/dslx/ir_convert/BUILD
@@ -335,6 +335,7 @@ cc_library(
":ir_conversion_utils",
":proc_config_ir_converter",
":proc_scoped_channel_scope",
+ "//xls/common:attribute_data",
"//xls/common:visitor",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
@@ -383,6 +384,7 @@ cc_library(
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
+ "@com_google_protobuf//:protobuf",
],
)
@@ -395,6 +397,7 @@ cc_test(
":convert_options",
":function_converter",
":test_utils",
+ "//xls/common:attribute_data",
"//xls/common:proto_test_utils",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
@@ -404,6 +407,8 @@ cc_test(
"//xls/dslx/frontend:ast",
"//xls/ir",
"//xls/ir:xls_ir_interface_cc_proto",
+ "@com_google_absl//absl/status",
+ "@com_google_absl//absl/types:span",
"@googletest//:gtest",
],
)
diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc
index 37828697f4..3eaee4eec1 100644
--- a/xls/dslx/ir_convert/function_converter.cc
+++ b/xls/dslx/ir_convert/function_converter.cc
@@ -40,6 +40,8 @@
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
+#include "google/protobuf/text_format.h"
+#include "xls/common/attribute_data.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/common/visitor.h"
@@ -3580,6 +3582,12 @@ absl::Status FunctionConverter::HandleFunction(
VLOG(5) << "Built function: " << ir_fn->name();
XLS_RETURN_IF_ERROR(VerifyFunction(ir_fn));
+  // Carry any `#[fuzz_test]` marker/domains from the DSLX wrapper onto the IR
+  // function so downstream tools (e.g. interface extraction) can see them.
+  XLS_ASSIGN_OR_RETURN(std::optional<AttributeData> fuzz_test_attr,
+                       LowerFuzzTestDomains(node));
+  if (fuzz_test_attr.has_value()) {
+    ir_fn->AddAttribute(std::move(*fuzz_test_attr));
+  }
+
// If it's a public fallible function, or it's the entry function for the
// package, we make a wrapper so that the external world (e.g. JIT, verilog
// module) doesn't need to take implicit token arguments.
@@ -3605,6 +3613,96 @@ absl::Status FunctionConverter::HandleFunction(
return absl::OkStatus();
}
+// Lowers one DSLX fuzz-test domain expression into `proto`:
+//   `()`              -> arbitrary
+//   `a..b`            -> range { min: a, max: b }
+//   `[v0, v1, ...]`   -> element_of { values: v0, v1, ... }
+//   `(d0, d1, ...)`   -> tuple whose elements are lowered recursively
+// Range bounds and array members must be constexpr-evaluable in the current
+// type info. Any other expression kind yields an UnimplementedError.
+absl::Status FunctionConverter::LowerDomainExpr(
+    Expr* expr, PackageInterfaceProto::FuzzTestDomain* proto) {
+  // Evaluates the constexpr `e` and converts it to an IR ValueProto.
+  auto to_value_proto = [this](Expr* e) -> absl::StatusOr<ValueProto> {
+    XLS_ASSIGN_OR_RETURN(InterpValue interp_value,
+                         current_type_info_->GetConstExpr(e));
+    XLS_ASSIGN_OR_RETURN(Value ir_value, InterpValueToValue(interp_value));
+    return ir_value.AsProto();
+  };
+
+  switch (expr->kind()) {
+    case AstNodeKind::kXlsTuple: {
+      auto* tuple_node = static_cast<XlsTuple*>(expr);
+      // The empty tuple is the "any value" domain.
+      if (tuple_node->empty()) {
+        proto->set_arbitrary(true);
+        return absl::OkStatus();
+      }
+      auto* tuple_proto = proto->mutable_tuple();
+      for (Expr* member : tuple_node->members()) {
+        XLS_RETURN_IF_ERROR(
+            LowerDomainExpr(member, tuple_proto->add_elements()));
+      }
+      return absl::OkStatus();
+    }
+    case AstNodeKind::kRange: {
+      auto* range_node = static_cast<Range*>(expr);
+      // NOTE(review): `end` is stored verbatim as `max`, regardless of
+      // whether the DSLX range is end-exclusive (`..`) or end-inclusive; confirm
+      // this matches how consumers interpret FuzzTestDomain.range.max.
+      XLS_ASSIGN_OR_RETURN(ValueProto min_proto,
+                           to_value_proto(range_node->start()));
+      XLS_ASSIGN_OR_RETURN(ValueProto max_proto,
+                           to_value_proto(range_node->end()));
+      auto* range_proto = proto->mutable_range();
+      *range_proto->mutable_min() = std::move(min_proto);
+      *range_proto->mutable_max() = std::move(max_proto);
+      return absl::OkStatus();
+    }
+    case AstNodeKind::kArray: {
+      auto* array_node = static_cast<Array*>(expr);
+      auto* element_of_proto = proto->mutable_element_of();
+      for (Expr* member : array_node->members()) {
+        XLS_ASSIGN_OR_RETURN(ValueProto value_proto, to_value_proto(member));
+        *element_of_proto->add_values() = std::move(value_proto);
+      }
+      return absl::OkStatus();
+    }
+    default:
+      return absl::UnimplementedError(
+          absl::StrCat("Unsupported fuzztest domain type: ", expr->ToString()));
+  }
+}
+
+absl::StatusOr<std::optional<AttributeData>>
+FunctionConverter::LowerFuzzTestDomains(Function* node) {
+  // Only functions wrapped by a DSLX `#[fuzz_test]` (FuzzTestFunction) get the
+  // attribute.
+  if (node->parent() == nullptr ||
+      node->parent()->kind() != AstNodeKind::kFuzzTestFunction) {
+    return std::nullopt;
+  }
+  auto* fuzz_test_fn = static_cast<FuzzTestFunction*>(node->parent());
+
+  std::vector<AttributeData::Argument> args;
+  if (fuzz_test_fn->domains().has_value()) {
+    XlsTuple* domains_tuple = *fuzz_test_fn->domains();
+    // A scratch Function proto is used solely so the serialized text carries
+    // the `parameter_domains` field wrapper; clients can parse the string back
+    // into a Function proto and recover the domains from it.
+    PackageInterfaceProto::Function domains_holder;
+    for (Expr* domain_expr : domains_tuple->members()) {
+      XLS_RETURN_IF_ERROR(LowerDomainExpr(
+          domain_expr, domains_holder.add_parameter_domains()));
+    }
+
+    std::string proto_str;
+    google::protobuf::TextFormat::Printer printer;
+    // Single-line text keeps the attribute on one line of the printed IR.
+    printer.SetSingleLineMode(true);
+    XLS_RET_CHECK(printer.PrintToString(domains_holder, &proto_str));
+
+    args.push_back(AttributeData::StringKeyValueArgument{
+        .first = "domains",
+        .second = std::move(proto_str),
+        .is_backticked = true});
+  }
+
+  // A `#[fuzz_test]` without explicit domains still marks the function as a
+  // fuzz test; the attribute simply carries no arguments.
+  return AttributeData(AttributeKind::kFuzzTest, std::move(args));
+}
+
absl::Status FunctionConverter::HandleSpawn(const Spawn* node) {
VLOG(5) << "HandleSpawn: " << node->ToString();
if (!options_.lower_to_proc_scoped_channels) {
diff --git a/xls/dslx/ir_convert/function_converter.h b/xls/dslx/ir_convert/function_converter.h
index a697860dc7..9e10b89ac3 100644
--- a/xls/dslx/ir_convert/function_converter.h
+++ b/xls/dslx/ir_convert/function_converter.h
@@ -31,6 +31,7 @@
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
+#include "xls/common/attribute_data.h"
#include "xls/dslx/channel_direction.h"
#include "xls/dslx/frontend/ast.h"
#include "xls/dslx/frontend/pos.h"
@@ -693,6 +694,15 @@ class FunctionConverter {
absl::flat_hash_map state_write_called_by_state_name_;
std::vector> proc_def_instances_;
+
+  // Returns the kFuzzTest AttributeData to attach to the IR lowering of
+  // `node` when `node` is the function wrapped by a DSLX `#[fuzz_test]`. When
+  // domains are present they are stored in a backticked `domains` argument as
+  // a single-line text-format PackageInterfaceProto::Function. Returns
+  // std::nullopt when no attribute should be attached.
+  absl::StatusOr<std::optional<AttributeData>> LowerFuzzTestDomains(
+      Function* node);
+
+  // Lowers a single DSLX domain expression (`()`, a range, an array of
+  // constants, or a tuple of domains) into `proto`.
+  absl::Status LowerDomainExpr(Expr* expr,
+                               PackageInterfaceProto::FuzzTestDomain* proto);
};
} // namespace xls::dslx
diff --git a/xls/dslx/ir_convert/function_converter_test.cc b/xls/dslx/ir_convert/function_converter_test.cc
index 72a556102b..7e39a6cc77 100644
--- a/xls/dslx/ir_convert/function_converter_test.cc
+++ b/xls/dslx/ir_convert/function_converter_test.cc
@@ -19,6 +19,8 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
+#include "absl/types/span.h"
+#include "xls/common/attribute_data.h"
#include "xls/common/proto_test_utils.h"
#include "xls/common/status/matchers.h"
#include "xls/dslx/create_import_data.h"
@@ -579,5 +581,352 @@ TEST(FunctionConverterTest, ConvertsFunctionWithUpdate2DBuiltinEmptyTuple) {
)pb"));
}
+// Returns how many non-overlapping times `needle` occurs in `haystack`.
+int CountOccurrences(std::string_view haystack, std::string_view needle) {
+  if (needle.empty()) {
+    return 0;
+  }
+  int count = 0;
+  for (size_t pos = haystack.find(needle); pos != std::string_view::npos;
+       pos = haystack.find(needle, pos + needle.size())) {
+    ++count;
+  }
+  return count;
+}
+
+// Type-checks `program`, which must define a `#[fuzz_test]` function named
+// `f`, converts `f` to IR, and stores in `*domains` the text payload of the IR
+// function's single `domains` kFuzzTest argument. Failures are reported via
+// gtest assertions, so callers should wrap calls in ASSERT_NO_FATAL_FAILURE.
+void ConvertFuzzTestFunctionAndGetDomains(std::string_view program,
+                                           std::string* domains) {
+  ImportData import_data = CreateImportDataForTest();
+  XLS_ASSERT_OK_AND_ASSIGN(
+      TypecheckedModule tm,
+      ParseAndTypecheck(program, "test_module.x", "test_module", &import_data));
+  XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft,
+                           tm.module->GetMemberOrError<FuzzTestFunction>("f"));
+  ASSERT_NE(ft, nullptr);
+
+  const ConvertOptions convert_options;
+  PackageConversionData package = MakeConversionData("test_module_package");
+  PackageData package_data{.conversion_info = &package};
+  FunctionConverter converter(package_data, tm.module, &import_data,
+                              convert_options, /*proc_data=*/nullptr,
+                              /*channel_scope=*/nullptr,
+                              /*is_top=*/true);
+  XLS_ASSERT_OK(converter.HandleFunction(&ft->fn(), tm.type_info,
+                                         /*parametric_env=*/nullptr));
+
+  ASSERT_FALSE(package_data.conversion_info->package->functions().empty());
+  auto* ir_fn =
+      package_data.conversion_info->package->functions().front().get();
+  EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest));
+
+  // Assert sizes before indexing so a regression fails cleanly instead of
+  // reading out of bounds.
+  absl::Span<const AttributeData> attributes = ir_fn->attributes();
+  ASSERT_EQ(attributes.size(), 1);
+  EXPECT_EQ(attributes[0].kind(), AttributeKind::kFuzzTest);
+  ASSERT_EQ(attributes[0].args().size(), 1);
+  const AttributeData::Argument& arg = attributes[0].args()[0];
+  const auto* skv = std::get_if<AttributeData::StringKeyValueArgument>(&arg);
+  ASSERT_NE(skv, nullptr);
+  EXPECT_EQ(skv->first, "domains");
+  EXPECT_TRUE(skv->is_backticked);
+  *domains = skv->second;
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithBasicDomains) {
+  constexpr std::string_view kProgram = R"(
+#[fuzz_test(domains = `u32:0..10, ()`)]
+fn f(x: u32, y: u32) -> u32 { x + y }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  // One range domain for `x`, one arbitrary domain for `y`.
+  EXPECT_EQ(CountOccurrences(domains, "parameter_domains"), 2);
+  EXPECT_THAT(domains, testing::HasSubstr("arbitrary: true"));
+  EXPECT_THAT(domains, testing::HasSubstr("range"));
+  EXPECT_THAT(domains, testing::HasSubstr("bit_count: 32"));
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithDifferentBitWidths) {
+  constexpr std::string_view kProgram = R"(
+#[fuzz_test(domains = `u8:0..5, u64:0..100`)]
+fn f(x: u8, y: u64) -> u64 { (x as u64) + y }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  EXPECT_EQ(CountOccurrences(domains, "parameter_domains"), 2);
+  EXPECT_THAT(domains, testing::HasSubstr("bit_count: 8"));
+  EXPECT_THAT(domains, testing::HasSubstr("bit_count: 64"));
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithRangeEdgeCases) {
+  constexpr std::string_view kProgram = R"(
+#[fuzz_test(domains = `u32:10..10, u32:0..0xFFFFFFFF`)]
+fn f(x: u32, y: u32) -> u32 { x + y }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  EXPECT_EQ(CountOccurrences(domains, "parameter_domains"), 2);
+  EXPECT_THAT(domains, testing::HasSubstr("bit_count: 32"));
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithMultipleSameDomains) {
+  constexpr std::string_view kProgram = R"(
+#[fuzz_test(domains = `u32:0..10, u32:20..30`)]
+fn f(x: u32, y: u32) -> u32 { x + y }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  // Two separate range domains, one per parameter.
+  EXPECT_EQ(CountOccurrences(domains, "parameter_domains"), 2);
+  EXPECT_EQ(CountOccurrences(domains, "range"), 2);
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithTupleDomain) {
+  constexpr std::string_view kProgram = R"(
+#[fuzz_test(domains = `(u32:0..10, u32:0..20)`)]
+fn f(x: (u32, u32)) -> u32 { x.0 }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  EXPECT_THAT(domains, testing::HasSubstr("parameter_domains"));
+  EXPECT_THAT(domains, testing::HasSubstr("tuple"));
+  // Each tuple element is lowered to its own range.
+  EXPECT_EQ(CountOccurrences(domains, "range"), 2);
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithElementOfDomain) {
+  constexpr std::string_view kProgram = R"(
+#[fuzz_test(domains = `[u32:1, u32:2, u32:3]`)]
+fn f(x: u32) -> u32 { x }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  EXPECT_THAT(domains, testing::HasSubstr("parameter_domains"));
+  EXPECT_THAT(domains, testing::HasSubstr("element_of"));
+  // One `values` entry per array member.
+  EXPECT_EQ(CountOccurrences(domains, "values"), 3);
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithEnumElementOfDomain) {
+  constexpr std::string_view kProgram = R"(
+enum MyEnum : u32 {
+  A = 1,
+  B = 2,
+}
+#[fuzz_test(domains = `[MyEnum::A, MyEnum::B]`)]
+fn f(x: MyEnum) -> MyEnum { x }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  EXPECT_THAT(domains, testing::HasSubstr("parameter_domains"));
+  EXPECT_THAT(domains, testing::HasSubstr("element_of"));
+  EXPECT_EQ(CountOccurrences(domains, "values"), 2);
+}
+
+TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithNestedTupleDomain) {
+  constexpr std::string_view kProgram = R"(
+#[fuzz_test(domains = `u32:0..10, ((), u32:0..5)`)]
+fn f(x: u32, y: ((), u32)) -> u32 { x }
+)";
+  std::string domains;
+  ASSERT_NO_FATAL_FAILURE(
+      ConvertFuzzTestFunctionAndGetDomains(kProgram, &domains));
+
+  EXPECT_EQ(CountOccurrences(domains, "parameter_domains"), 2);
+  EXPECT_THAT(domains, testing::HasSubstr("tuple"));
+  EXPECT_THAT(domains, testing::HasSubstr("arbitrary: true"));
+}
+
} // namespace
} // namespace xls::dslx
diff --git a/xls/dslx/ir_convert/ir_converter.cc b/xls/dslx/ir_convert/ir_converter.cc
index 130198636e..3ef73168c1 100644
--- a/xls/dslx/ir_convert/ir_converter.cc
+++ b/xls/dslx/ir_convert/ir_converter.cc
@@ -522,6 +522,12 @@ absl::Status ConvertOneFunctionIntoPackage(Module* module,
return ConvertOneFunctionIntoPackageInternal(&(*test_fn)->fn(), import_data,
options, conv);
}
+  // `#[fuzz_test]` functions are module members of type FuzzTestFunction;
+  // convert the function they wrap.
+  absl::StatusOr<FuzzTestFunction*> fuzz_test_fn =
+      module->GetMemberOrError<FuzzTestFunction>(entry_function_name);
+  if (fuzz_test_fn.ok()) {
+    return ConvertOneFunctionIntoPackageInternal(&(*fuzz_test_fn)->fn(),
+                                                 import_data, options, conv);
+  }
std::optional fn_or = module->GetFunction(entry_function_name);
if (fn_or.has_value()) {
diff --git a/xls/ir/BUILD b/xls/ir/BUILD
index 3acf6fea70..2f9ac7b4da 100644
--- a/xls/ir/BUILD
+++ b/xls/ir/BUILD
@@ -895,6 +895,7 @@ cc_library(
":verifier",
"//xls/codegen:module_signature",
"//xls/codegen:module_signature_cc_proto",
+ "//xls/common:attribute_data",
"//xls/common:visitor",
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
@@ -928,6 +929,7 @@ cc_test(
":state_element",
":type",
":value",
+ "//xls/common:attribute_data",
"//xls/common:source_location",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
@@ -1570,7 +1572,10 @@ py_proto_library(
proto_library(
name = "xls_ir_interface_proto",
srcs = ["xls_ir_interface.proto"],
- deps = [":xls_type_proto"],
+ deps = [
+ ":xls_type_proto",
+ ":xls_value_proto",
+ ],
)
cc_proto_library(
diff --git a/xls/ir/function.cc b/xls/ir/function.cc
index c99e0e4389..45157be200 100644
--- a/xls/ir/function.cc
+++ b/xls/ir/function.cc
@@ -32,9 +32,11 @@
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
+#include "absl/types/variant.h"
#include "xls/common/attribute_data.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
+#include "xls/common/visitor.h"
#include "xls/ir/change_listener.h"
#include "xls/ir/function_base.h"
#include "xls/ir/ir_annotator.h"
@@ -348,8 +350,32 @@ std::vector Function::AttributeIrStrings() const {
attribute_strings.push_back("non_synth");
}
for (const auto& attr : attributes_) {
- // TODO - davidplass: Properly serialize AttributeData arguments.
- attribute_strings.push_back(AttributeKindToString(attr.kind()));
+    std::string attr_str = AttributeKindToString(attr.kind());
+    if (!attr.args().empty()) {
+      // Each argument renders as one of: a bare identifier, a "quoted"
+      // literal, key=`backticked` or key="quoted", or key=int.
+      // NOTE(review): payload text is emitted verbatim with no escaping, so a
+      // value containing `"` (or a backtick, when backticked) likely will not
+      // round-trip through the IR text; confirm whether escaping is needed.
+      std::vector<std::string> arg_strings;
+      arg_strings.reserve(attr.args().size());
+      for (const AttributeData::Argument& arg : attr.args()) {
+        arg_strings.push_back(absl::visit(
+            Visitor{
+                [](const std::string& s) { return s; },
+                [](const AttributeData::StringLiteralArgument& s) {
+                  return absl::StrCat("\"", s.text, "\"");
+                },
+                [](const AttributeData::StringKeyValueArgument& s) {
+                  if (s.is_backticked) {
+                    return absl::StrCat(s.first, "=`", s.second, "`");
+                  }
+                  return absl::StrCat(s.first, "=\"", s.second, "\"");
+                },
+                [](const AttributeData::IntKeyValueArgument& s) {
+                  return absl::StrCat(s.first, "=", s.second);
+                },
+            },
+            arg));
+      }
+      absl::StrAppend(&attr_str, "(", absl::StrJoin(arg_strings, ", "), ")");
+    }
+    attribute_strings.push_back(std::move(attr_str));
}
return attribute_strings;
}
diff --git a/xls/ir/function_test.cc b/xls/ir/function_test.cc
index 1f11e09183..1615c2ad8b 100644
--- a/xls/ir/function_test.cc
+++ b/xls/ir/function_test.cc
@@ -642,5 +642,26 @@ TEST_F(FunctionTest, AttributeDataTest) {
EXPECT_THAT(f->DumpIr(), HasSubstr("#[fuzz_test]"));
}
+TEST_F(FunctionTest, AttributeSerializationWithArgumentsTest) {
+  auto p = CreatePackage();
+  FunctionBuilder b("f", p.get());
+  b.Param("x", p->GetBitsType(32));
+  XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.BuildWithReturnValue(b.Tuple({})));
+
+  // One argument of every AttributeData::Argument alternative, including both
+  // the backticked and the double-quoted key=value forms.
+  std::vector<AttributeData::Argument> args;
+  args.push_back("ident");
+  args.push_back(AttributeData::StringLiteralArgument{.text = "literal"});
+  args.push_back(AttributeData::StringKeyValueArgument{
+      .first = "key", .second = "backticked", .is_backticked = true});
+  args.push_back(AttributeData::StringKeyValueArgument{
+      .first = "key2", .second = "quoted", .is_backticked = false});
+  args.push_back(AttributeData::IntKeyValueArgument{"int_key", 42});
+
+  f->AddAttribute(AttributeData(AttributeKind::kFuzzTest, std::move(args)));
+
+  EXPECT_THAT(f->DumpIr(),
+              HasSubstr("#[fuzz_test(ident, \"literal\", key=`backticked`, "
+                        "key2=\"quoted\", int_key=42)]"));
+}
+
} // namespace
} // namespace xls
diff --git a/xls/ir/ir_parser.cc b/xls/ir/ir_parser.cc
index 72b7dbcaf1..7566f1436e 100644
--- a/xls/ir/ir_parser.cc
+++ b/xls/ir/ir_parser.cc
@@ -44,6 +44,7 @@
#include "google/protobuf/text_format.h"
#include "xls/codegen/module_signature.h"
#include "xls/codegen/module_signature.pb.h"
+#include "xls/common/attribute_data.h"
#include "xls/common/status/ret_check.h"
#include "xls/common/status/status_macros.h"
#include "xls/common/visitor.h"
@@ -2386,21 +2387,39 @@ absl::StatusOr Parser::ParseFunctionInternal(
}
for (const IrAttribute& attribute : outer_attributes) {
- if (std::holds_alternative(attribute.payload)) {
- result->SetInitiationInterval(
- std::get(attribute.payload).value);
- } else if (std::holds_alternative(attribute.payload)) {
- // Dummy parse to make sure it is a valid template.
- const ForeignFunctionData& ffi =
- std::get(attribute.payload);
- XLS_RETURN_IF_ERROR(CodeTemplate::Create(ffi.code_template()).status());
- result->SetForeignFunctionData(ffi);
- } else if (std::holds_alternative(attribute.payload)) {
- result->set_non_synth(true);
- } else {
- return absl::InvalidArgumentError(absl::StrFormat(
- "Invalid attribute for function: %s", attribute.name));
- }
+    // Apply each parsed outer attribute to the function; attribute kinds that
+    // are not valid on functions fall through to the catch-all and are
+    // rejected.
+    XLS_RETURN_IF_ERROR(absl::visit(
+        Visitor{
+            [&](const InitiationInterval& ii) -> absl::Status {
+              result->SetInitiationInterval(ii.value);
+              return absl::OkStatus();
+            },
+            [&](const ForeignFunctionData& ffi) -> absl::Status {
+              // Dummy parse to make sure it is a valid template.
+              XLS_RETURN_IF_ERROR(
+                  CodeTemplate::Create(ffi.code_template()).status());
+              result->SetForeignFunctionData(ffi);
+              return absl::OkStatus();
+            },
+            [&](const NonSynthMarker&) -> absl::Status {
+              result->set_non_synth(true);
+              return absl::OkStatus();
+            },
+            [&](const FuzzTestAttribute& fuzz_test_attr) -> absl::Status {
+              // `#[fuzz_test]` becomes a kFuzzTest attribute; the optional
+              // domains text is kept as a backticked `domains` argument.
+              std::vector<AttributeData::Argument> args;
+              if (fuzz_test_attr.domains_proto_str.has_value()) {
+                args.push_back(AttributeData::StringKeyValueArgument{
+                    .first = "domains",
+                    .second = *fuzz_test_attr.domains_proto_str,
+                    .is_backticked = true});
+              }
+              result->AddAttribute(
+                  AttributeData(AttributeKind::kFuzzTest, std::move(args)));
+              return absl::OkStatus();
+            },
+            [&](const auto&) -> absl::Status {
+              return absl::InvalidArgumentError(absl::StrFormat(
+                  "Invalid attribute for function: %s", attribute.name));
+            }},
+        attribute.payload));
}
return result;
@@ -3020,6 +3039,28 @@ absl::StatusOr Parser::ParseAttribute(Package* package) {
scanner_.DropTokenOrError(LexicalTokenType::kParenClose));
return IrAttribute{.name = attribute_name.value(), .payload = ffi};
}
+ if (attribute_name.value() == "fuzz_test") {
+ std::optional domains_proto_str;
+
+ if (scanner_.TryDropToken(LexicalTokenType::kParenOpen)) {
+ absl::flat_hash_map> handlers;
+ handlers["domains"] = [&]() -> absl::Status {
+ XLS_ASSIGN_OR_RETURN(
+ Token token,
+ scanner_.PopTokenOrError(LexicalTokenType::kBacktickedString,
+ "backticked string"));
+ domains_proto_str = token.value();
+ return absl::OkStatus();
+ };
+ XLS_RETURN_IF_ERROR(ParseKeywordArguments(handlers));
+ XLS_RETURN_IF_ERROR(
+ scanner_.DropTokenOrError(LexicalTokenType::kParenClose));
+ }
+
+ return IrAttribute{
+ .name = attribute_name.value(),
+ .payload = FuzzTestAttribute{.domains_proto_str = domains_proto_str}};
+ }
if (attribute_name.value() == "channel_ports") {
std::optional channel_name;
std::optional type;
diff --git a/xls/ir/ir_parser.h b/xls/ir/ir_parser.h
index 4e7ab138bf..2cd21afeba 100644
--- a/xls/ir/ir_parser.h
+++ b/xls/ir/ir_parser.h
@@ -74,10 +74,14 @@ struct ResetAttribute {
struct NonSynthMarker : public std::monostate {};
+// Payload of a `#[fuzz_test]` function attribute. `domains_proto_str` holds
+// the contents of the optional backticked `domains` keyword argument (without
+// the backticks); it is empty when the attribute was written without one.
+struct FuzzTestAttribute {
+ std::optional domains_proto_str;
+};
+
using IrAttributePayload =
std::variant;
+ NonSynthMarker, FuzzTestAttribute>;
struct IrAttribute {
std::string name;
IrAttributePayload payload;
diff --git a/xls/ir/ir_parser_test.cc b/xls/ir/ir_parser_test.cc
index 9049b833df..5a1e7ec2c8 100644
--- a/xls/ir/ir_parser_test.cc
+++ b/xls/ir/ir_parser_test.cc
@@ -19,6 +19,7 @@
#include
#include
#include
+#include
#include "gmock/gmock.h"
#include "gtest/gtest.h"
@@ -28,6 +29,7 @@
#include "absl/strings/str_cat.h"
#include "absl/strings/substitute.h"
#include "absl/types/span.h"
+#include "xls/common/attribute_data.h"
#include "xls/common/source_location.h"
#include "xls/common/status/matchers.h"
#include "xls/ir/bits.h"
@@ -1206,4 +1208,94 @@ scheduled_block b() {
EXPECT_EQ(block->stages().size(), 1);
}
+// A `fuzz_test` attribute with a backticked `domains` argument is attached to
+// the function as a single kFuzzTest attribute carrying that argument.
+TEST(IrParserTest, ParseAttributeWithArguments) {
+ std::string program = R"(
+package test
+
+#[fuzz_test(domains = `u32:0..100, ()`)]
+fn main(x: bits[32]) -> bits[32] {
+ ret x: bits[32] = param(name=x)
+}
+)";
+ XLS_ASSERT_OK_AND_ASSIGN(auto package, Parser::ParsePackage(program));
+ EXPECT_EQ(package->functions().size(), 1);
+ Function* func = package->functions().front().get();
+ EXPECT_TRUE(func->HasAttribute(AttributeKind::kFuzzTest));
+ absl::Span attributes = func->attributes();
+ ASSERT_EQ(attributes.size(), 1);
+ const AttributeData& attr = attributes[0];
+ EXPECT_EQ(attr.kind(), AttributeKind::kFuzzTest);
+ ASSERT_EQ(attr.args().size(), 1);
+ // The domains text is kept verbatim as a backticked string key/value arg.
+ const AttributeData::Argument& arg = attr.args()[0];
+ ASSERT_TRUE(
+ std::holds_alternative(arg));
+ const auto& skv = std::get(arg);
+ EXPECT_EQ(skv.first, "domains");
+ EXPECT_EQ(skv.second, "u32:0..100, ()");
+ EXPECT_TRUE(skv.is_backticked);
+}
+
+// A backticked string with no closing backtick on the same line must be
+// rejected with an "Unterminated quoted string" error.
+TEST(IrParserTest, ParseAttributeWithBrokenBacktick) {
+ std::string program = R"(
+package test
+
+#[fuzz_test(domains = `u32:0..100, ())]
+fn main(x: bits[32]) -> bits[32] {
+ ret x: bits[32] = param(name=x)
+}
+)";
+ EXPECT_THAT(Parser::ParsePackage(program).status(),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("Unterminated quoted string")));
+}
+
+// A bare `#[fuzz_test]` (no parentheses) is accepted and produces a kFuzzTest
+// attribute with no arguments.
+TEST(IrParserTest, ParseAttributeEmptyFuzzTest) {
+ std::string program = R"ir(
+package test
+
+#[fuzz_test]
+fn main(x: bits[32]) -> bits[32] {
+ ret x: bits[32] = param(name=x)
+}
+)ir";
+ XLS_ASSERT_OK_AND_ASSIGN(auto package, Parser::ParsePackage(program));
+ EXPECT_EQ(package->functions().size(), 1);
+ Function* func = package->functions().front().get();
+ EXPECT_TRUE(func->HasAttribute(AttributeKind::kFuzzTest));
+ absl::Span attributes = func->attributes();
+ ASSERT_EQ(attributes.size(), 1);
+ const AttributeData& attr = attributes[0];
+ EXPECT_EQ(attr.kind(), AttributeKind::kFuzzTest);
+ EXPECT_TRUE(attr.args().empty());
+}
+
+// Only the `domains` keyword is registered for `fuzz_test`; any other keyword
+// (here `extra`) is expected to fail with "Invalid keyword argument".
+TEST(IrParserTest, ParseAttributeExtraArgument) {
+ std::string program = R"ir(
+package test
+
+#[fuzz_test(domains = `u32:0..100, ()`, extra = 42)]
+fn main(x: bits[32]) -> bits[32] {
+ ret x: bits[32] = param(name=x)
+}
+)ir";
+ EXPECT_THAT(Parser::ParsePackage(program).status(),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("Invalid keyword argument")));
+}
+
+// The `domains` value must be a backticked string; a double-quoted string is
+// rejected by the token-type check in the `domains` handler.
+TEST(IrParserTest, ParseAttributeInvalidDomainType) {
+ std::string program = R"ir(
+package test
+
+#[fuzz_test(domains = "u32:0..100, ()")]
+fn main(x: bits[32]) -> bits[32] {
+ ret x: bits[32] = param(name=x)
+}
+)ir";
+ EXPECT_THAT(
+ Parser::ParsePackage(program).status(),
+ StatusIs(absl::StatusCode::kInvalidArgument,
+ HasSubstr("Expected token of type \"backticked string\"")));
+}
+
} // namespace xls
diff --git a/xls/ir/ir_scanner.cc b/xls/ir/ir_scanner.cc
index 142f1c17c1..48055705bb 100644
--- a/xls/ir/ir_scanner.cc
+++ b/xls/ir/ir_scanner.cc
@@ -74,6 +74,8 @@ std::string LexicalTokenTypeToString(LexicalTokenType token_type) {
return "(";
case LexicalTokenType::kQuotedString:
return "quoted string";
+ case LexicalTokenType::kBacktickedString:
+ return "backticked string";
case LexicalTokenType::kRightArrow:
return "->";
case LexicalTokenType::kHash:
@@ -304,6 +306,13 @@ class Tokenizer {
start_lineno, start_colno));
continue;
}
+ XLS_ASSIGN_OR_RETURN(content,
+ MatchQuotedString("`", /*allow_multiline=*/false));
+ if (content.has_value()) {
+ tokens.push_back(Token(LexicalTokenType::kBacktickedString,
+ content.value(), start_lineno, start_colno));
+ continue;
+ }
// Handle single-character tokens.
LexicalTokenType token_type;
diff --git a/xls/ir/ir_scanner.h b/xls/ir/ir_scanner.h
index 12ca341a1b..81b13e740f 100644
--- a/xls/ir/ir_scanner.h
+++ b/xls/ir/ir_scanner.h
@@ -53,6 +53,7 @@ enum class LexicalTokenType {
kRightArrow,
kHash,
kBang,
+ kBacktickedString,
};
std::string LexicalTokenTypeToString(LexicalTokenType token_type);
diff --git a/xls/ir/xls_ir_interface.proto b/xls/ir/xls_ir_interface.proto
index 34e734a476..91e5be5d5a 100644
--- a/xls/ir/xls_ir_interface.proto
+++ b/xls/ir/xls_ir_interface.proto
@@ -17,6 +17,7 @@ syntax = "proto3";
package xls;
import "xls/ir/xls_type.proto";
+import "xls/ir/xls_value.proto";
message PackageInterfaceProto {
// A generic thing with a name and a type.
@@ -59,6 +60,10 @@ message PackageInterfaceProto {
optional TypeProto result_type = 3;
// If present the corresponding sv type for the result of this function.
optional string sv_result_type = 4;
+
+ // The structured fuzzing domains for this function. There may be
+ // either zero domains, or exactly one domain per parameter.
+ repeated FuzzTestDomain parameter_domains = 5;
}
message Proc {
@@ -87,6 +92,31 @@ message PackageInterfaceProto {
repeated NamedValue output_ports = 4;
}
+ // The set of values a fuzz test may generate for one function parameter, or
+ // for one element of a tuple/array parameter.
+ message FuzzTestDomain {
+ // Values bounded below by `min` and above by `max`.
+ message Range {
+ optional ValueProto min = 1;
+ optional ValueProto max = 2;
+ }
+ // Values drawn from an explicit list of candidates.
+ message ElementOf {
+ repeated ValueProto values = 1;
+ }
+ // One domain per element of a tuple-typed value, in element order.
+ message Tuple {
+ repeated FuzzTestDomain elements = 1;
+ }
+ // A domain applied to the elements of an array-typed value.
+ message Array {
+ optional FuzzTestDomain element_domain = 1;
+ // NOTE(review): presumably mirrors the array type's length; confirm.
+ optional int64 size = 2;
+ }
+
+ // At most one kind is set; `arbitrary` requests an unconstrained domain.
+ oneof domain_kind {
+ bool arbitrary = 1;
+ Range range = 2;
+ ElementOf element_of = 3;
+ Tuple tuple = 4;
+ Array array = 5;
+ }
+ }
+
// Name of the overall package.
optional string name = 1;
diff --git a/xls/jit/BUILD b/xls/jit/BUILD
index 8289daee73..7efe702dda 100644
--- a/xls/jit/BUILD
+++ b/xls/jit/BUILD
@@ -411,6 +411,7 @@ pytype_strict_contrib_test(
deps = [
":jit_wrapper_generator",
"//xls/common:runfiles",
+ "//xls/ir:xls_ir_interface_py_pb2",
"//xls/ir:xls_type_py_pb2",
"@abseil-py//absl/testing:absltest",
"@xls_pip_deps//jinja2",
diff --git a/xls/jit/fuzztest_cc.tmpl b/xls/jit/fuzztest_cc.tmpl
index 14061bf52e..546af21a8c 100644
--- a/xls/jit/fuzztest_cc.tmpl
+++ b/xls/jit/fuzztest_cc.tmpl
@@ -9,14 +9,29 @@
namespace {{ fuzztest.namespace }} {
namespace {
-void {{fuzztest.property_function_name}}({{ fuzztest.params | map("property_param") | join (", ") }}) {
+void {{fuzztest.property_function_name}}({% for param in fuzztest.params %}{{param.cpp_type}} {{param.name}}{% if not loop.last %}, {% endif %}{% endfor %}) {
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<{{fuzztest.lib_class_name}}> f,
{{fuzztest.lib_class_name}}::Create());
+ {% if fuzztest.can_be_specialized %}
+ // Use specialized overload.
+ XLS_ASSERT_OK_AND_ASSIGN(auto res_native, f->Run(
+ {% for param in fuzztest.params %}{{param.name}}{% if not loop.last %}, {% endif %}{% endfor %}
+ ));
+ xls::Value result = xls::Value(xls::UBits(res_native, {{fuzztest.result_width}}));
+ {% else %}
+ // Fallback: Manual binding for the standard Run(xls::Value...) overload.
+ {% for param in fuzztest.params %}
+ {% if param.is_native %}
+ xls::Value {{param.name}}_val = {{param.conversion_snippet}};
+ {% endif %}
+ {% endfor %}
XLS_ASSERT_OK_AND_ASSIGN(xls::Value result, f->Run(
{% for param in fuzztest.params %}
- {{param.name}}{% if not loop.last %}, {% endif %}
- {% endfor %}));
+ {{param.name}}{% if param.is_native %}_val{% endif %}{% if not loop.last %}, {% endif %}
+ {% endfor %}
+ ));
+ {% endif %}
{% if fuzztest.return_type %}
ASSERT_TRUE(result.IsAllOnes());
@@ -26,7 +41,10 @@ void {{fuzztest.property_function_name}}({{ fuzztest.params | map("property_para
FUZZ_TEST({{fuzztest.fuzztest_name}}, {{fuzztest.property_function_name}})
.WithDomains(
{% for param in fuzztest.params %}
- xls::ArbitraryValue({{fuzztest.lib_class_name}}::GetParamType({{param.index}}).value()){% if not loop.last %},
+ {% if param.domain_snippet %}
+ {{param.domain_snippet}}{% if not loop.last %}, {% endif %}
+ {% else %}
+ xls::ArbitraryValue({{fuzztest.lib_class_name}}::GetParamType({{param.index}}).value()){% if not loop.last %}, {% endif %}
{% endif %}
{% endfor %}
);
diff --git a/xls/jit/jit_wrapper_generator.py b/xls/jit/jit_wrapper_generator.py
index 49759cc165..c138a15354 100644
--- a/xls/jit/jit_wrapper_generator.py
+++ b/xls/jit/jit_wrapper_generator.py
@@ -27,6 +27,16 @@
from xls.jit import aot_entrypoint_pb2
+@dataclasses.dataclass(frozen=True)
+class FuzzTestInfo:
+  """FuzzTest-specific information attached to a wrapped value.
+
+  Attributes:
+    domain_snippet: C++ FuzzTest domain expression derived from
+      `domain_proto`, or None when the generated test should fall back to
+      `xls::ArbitraryValue`.
+    domain_proto: The domain specification from the package interface, if
+      any.
+  """
+
+  domain_snippet: Optional[str] = None
+  domain_proto: Optional[
+      ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain
+  ] = None
+
+
@dataclasses.dataclass(frozen=True)
class XlsNamedValue:
"""A Named & typed value for the wrapped function/proc."""
@@ -36,6 +46,7 @@ class XlsNamedValue:
unpacked_type: str
specialized_type: Optional[str]
type_proto: type_pb2.TypeProto
+ fuzztest_info: Optional[FuzzTestInfo] = None
@property
def value_arg(self):
@@ -123,6 +134,10 @@ class PropertyFunctionParam:
name: str
# The index of the parameter in the function signature.
index: int
+ domain_snippet: Optional[str] = None
+ cpp_type: str = "xls::Value"
+ is_native: bool = False
+ conversion_snippet: Optional[str] = None
@dataclasses.dataclass(frozen=True)
@@ -137,6 +152,8 @@ class PropertyFunction:
# Function params and result.
params: Sequence[PropertyFunctionParam]
return_type: bool
+ can_be_specialized: bool = False
+ result_width: int = 0
def to_packed(t: type_pb2.TypeProto) -> str:
@@ -263,6 +280,110 @@ def to_specialized(
return None
+def extract_int_from_bytes(data: bytes) -> int:
+  """Decodes `data` as an unsigned integer, least-significant byte first.
+
+  An empty byte string decodes to 0.
+  """
+  decoded = int.from_bytes(data, "little")
+  return decoded
+
+
+def can_use_uint64_range(
+    t: type_pb2.TypeProto,
+    d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain],
+) -> bool:
+  """Reports whether a range domain on a bits type fits in a uint64_t.
+
+  Only the range's upper bound is inspected.
+
+  Args:
+    t: The XLS type of the value the domain applies to.
+    d: The optional domain specification.
+
+  Returns:
+    True iff `t` is a bits type, `d` is a range domain, and the range maximum
+    is below 2**64.
+  """
+  if t.type_enum != type_pb2.TypeProto.BITS:
+    return False
+  if not d or not d.HasField("range"):
+    return False
+  upper_bound = extract_int_from_bytes(d.range.max.bits.data)
+  return upper_bound < 2**64
+
+
+def to_domain(
+    t: type_pb2.TypeProto,
+    d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain],
+) -> Optional[str]:
+  """Converts an XLS type and domain spec to a FuzzTest domain string.
+
+  Args:
+    t: The XLS type proto.
+    d: The optional FuzzTest domain specification from the package interface.
+      None, an `arbitrary` domain, or a domain with no kind set all request an
+      unconstrained domain over `t`.
+
+  Returns:
+    A string representing the C++ FuzzTest domain (e.g.,
+    "fuzztest::Arbitrary<uint32_t>()"), or None if it should fallback to
+    xls::ArbitraryValue.
+
+  Raises:
+    app.UsageError: If the domain specification is invalid or unsupported for
+      the given type: a kind/type mismatch, an inverted range, an empty
+      ElementOf, or a tuple with an element that has no FuzzTest domain.
+  """
+  kind = None if d is None else d.WhichOneof("domain_kind")
+
+  if kind is None or kind == "arbitrary":
+    if t.type_enum == type_pb2.TypeProto.BITS:
+      c_type = to_specialized(t, int_only=True)
+      if c_type is None:
+        return None
+      if t.bit_count in (8, 16, 32, 64):
+        return f"fuzztest::Arbitrary<{c_type}>()"
+      # Restrict to the values representable in `bit_count` bits.
+      max_val = (1 << t.bit_count) - 1
+      return f"fuzztest::InRange<{c_type}>(0, {max_val})"
+    if t.type_enum == type_pb2.TypeProto.TUPLE:
+      elems = [to_domain(e, None) for e in t.tuple_elements]
+      if any(e is None for e in elems):
+        return None
+      return f"fuzztest::TupleOf({', '.join(elems)})"
+    if t.type_enum == type_pb2.TypeProto.ARRAY:
+      elem_domain = to_domain(t.array_element, None)
+      if elem_domain is None:
+        return None
+      return f"fuzztest::VectorOf({elem_domain}).WithSize({t.array_size})"
+    return None
+
+  if kind == "range":
+    c_type = to_specialized(t, int_only=True)
+    if c_type is None:
+      if can_use_uint64_range(t, d):
+        c_type = "uint64_t"
+      else:
+        raise app.UsageError(
+            "Range domain only supported for specializable bits types or ranges"
+            " fitting in 64 bits"
+        )
+    min_val = extract_int_from_bytes(d.range.min.bits.data)
+    max_val = extract_int_from_bytes(d.range.max.bits.data)
+    if min_val > max_val:
+      # fuzztest::InRange requires min <= max; report it at generation time
+      # instead of when the fuzz test starts.
+      raise app.UsageError(
+          f"Range domain minimum {min_val} exceeds maximum {max_val}"
+      )
+    return f"fuzztest::InRange<{c_type}>({min_val}, {max_val})"
+
+  if kind == "element_of":
+    c_type = to_specialized(t, int_only=True)
+    if c_type is None:
+      raise app.UsageError(
+          "ElementOf domain only supported for specializable bits types in"
+          " this CL"
+      )
+    if not d.element_of.values:
+      # fuzztest::ElementOf needs at least one candidate value.
+      raise app.UsageError("ElementOf domain requires at least one value")
+    vals = [
+        str(extract_int_from_bytes(v.bits.data)) for v in d.element_of.values
+    ]
+    return f"fuzztest::ElementOf(std::vector<{c_type}>{{{', '.join(vals)}}})"
+
+  if kind == "tuple":
+    if t.type_enum != type_pb2.TypeProto.TUPLE:
+      raise app.UsageError("Tuple domain requires Tuple type")
+    if len(d.tuple.elements) != len(t.tuple_elements):
+      raise app.UsageError("Tuple domain and type element count mismatch")
+    elems = [
+        to_domain(te, de) for te, de in zip(t.tuple_elements, d.tuple.elements)
+    ]
+    if any(e is None for e in elems):
+      # Joining None would crash, and falling back to xls::ArbitraryValue for
+      # the whole tuple could silently drop constraints on other elements.
+      raise app.UsageError(
+          "Tuple domain contains an element with no FuzzTest domain"
+      )
+    return f"fuzztest::TupleOf({', '.join(elems)})"
+
+  if kind == "array":
+    if t.type_enum != type_pb2.TypeProto.ARRAY:
+      raise app.UsageError("Array domain requires Array type")
+    if d.array.HasField("size") and d.array.size != t.array_size:
+      raise app.UsageError("Array domain and type size mismatch")
+    elem_spec = (
+        d.array.element_domain if d.array.HasField("element_domain") else None
+    )
+    elem_domain = to_domain(t.array_element, elem_spec)
+    if elem_domain is None:
+      return None
+    return f"fuzztest::VectorOf({elem_domain}).WithSize({t.array_size})"
+
+  raise app.UsageError(f"Unsupported domain: {d}")
+
+
def to_chan(
c: ir_interface_pb2.PackageInterfaceProto.Channel, package_name: str
) -> XlsChannel:
@@ -278,13 +399,21 @@ def to_chan(
def to_param(
p: ir_interface_pb2.PackageInterfaceProto.NamedValue,
+ d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain] = None,
) -> XlsNamedValue:
+ """Builds the XlsNamedValue describing one function parameter.
+
+ Args:
+ p: The parameter's name and type from the package interface.
+ d: Optional FuzzTest domain for the parameter. When given, it is converted
+ with `to_domain` and recorded, together with `d` itself, in
+ `fuzztest_info`.
+
+ Returns:
+ The named value; its `fuzztest_info` is None when `d` is None.
+
+ Raises:
+ app.UsageError: Propagated from `to_domain` for invalid domains.
+ """
+ fuzztest_info = None
+ if d is not None:
+ fuzztest_info = FuzzTestInfo(
+ domain_snippet=to_domain(p.type, d), domain_proto=d
+ )
+
return XlsNamedValue(
name=p.name,
packed_type=to_packed(p.type),
unpacked_type=to_unpacked(p.type),
specialized_type=to_specialized(p.type),
type_proto=p.type,
+ fuzztest_info=fuzztest_info,
)
@@ -330,7 +459,12 @@ def interpret_function_interface(
"""
if func_ir.base.name != aot_info.entrypoint[0].xls_function_identifier:
raise app.UsageError("Aot info is for a different function.")
- params = [to_param(p) for p in func_ir.parameters]
+ params = []
+ for idx, p in enumerate(func_ir.parameters):
+ d = None
+ if idx < len(func_ir.parameter_domains):
+ d = func_ir.parameter_domains[idx]
+ params.append(to_param(p, d))
result = XlsNamedValue(
name="result",
packed_type=to_packed(func_ir.result_type),
@@ -458,16 +592,52 @@ def wrapped_to_fuzztest(
params = []
if wrapped.params:
for idx, p in enumerate(wrapped.params):
- params.append(PropertyFunctionParam(name=p.name, index=idx))
+ is_native = p.specialized_type is not None
+ cpp_type = p.specialized_type or "xls::Value"
+ conversion_snippet = None
+
+ domain_proto = p.fuzztest_info.domain_proto if p.fuzztest_info else None
+ domain_snippet = (
+ p.fuzztest_info.domain_snippet if p.fuzztest_info else None
+ )
+
+ if can_use_uint64_range(p.type_proto, domain_proto):
+ cpp_type = "uint64_t"
+ is_native = True
+
+ if is_native and p.type_proto.type_enum == type_pb2.TypeProto.BITS:
+ conversion_snippet = (
+ f"xls::Value(xls::UBits({p.name}, {p.type_proto.bit_count}))"
+ )
+
+ params.append(
+ PropertyFunctionParam(
+ name=p.name,
+ index=idx,
+ domain_snippet=domain_snippet,
+ cpp_type=cpp_type,
+ is_native=is_native,
+ conversion_snippet=conversion_snippet,
+ )
+ )
+
+ result_width = 0
+ if (
+ wrapped.result
+ and wrapped.result.type_proto.type_enum == type_pb2.TypeProto.BITS
+ ):
+ result_width = wrapped.result.type_proto.bit_count
+
return PropertyFunction(
fuzztest_name=wrapped.function_name + "_fuzztest",
property_function_name=wrapped.function_name,
lib_class_name=lib_class_name,
lib_header_path=lib_header_path,
- # Everything shares the namespace
namespace=wrapped.namespace,
params=params,
return_type=wrapped.result is not None,
+ can_be_specialized=wrapped.can_be_specialized,
+ result_width=result_width,
)
@@ -479,7 +649,11 @@ def render_fuzztest(
lib_header_path: str,
) -> str:
"""Renders the fuzztest C++ code."""
- env.filters["property_param"] = lambda p: "xls::Value " + p.name
+ env.filters["property_param"] = lambda p: (
+ f"const {p.c_type}& {p.name}"
+ if p.c_type == "xls::Value"
+ else f"{p.c_type} {p.name}"
+ )
cc_template = env.from_string(cc_template_content)
fuzztest = wrapped_to_fuzztest(wrapped, lib_class_name, lib_header_path)
bindings = {"fuzztest": fuzztest, "len": len}
diff --git a/xls/jit/jit_wrapper_generator_test.py b/xls/jit/jit_wrapper_generator_test.py
index 9a025c5019..baa3e0163f 100644
--- a/xls/jit/jit_wrapper_generator_test.py
+++ b/xls/jit/jit_wrapper_generator_test.py
@@ -16,6 +16,7 @@
from absl.testing import absltest
from xls.common import runfiles
+from xls.ir import xls_ir_interface_pb2 as ir_interface_pb2
from xls.ir import xls_type_pb2 as type_pb2
from xls.jit import jit_wrapper_generator
@@ -576,5 +577,146 @@ def test_render_fuzztest_tuple_mixed(self):
)
+class JitWrapperGeneratorToDomainTest(absltest.TestCase):
+ """Unit tests for jit_wrapper_generator.to_domain and its byte helper."""
+
+ def test_extract_int_from_bytes(self):
+ self.assertEqual(jit_wrapper_generator.extract_int_from_bytes(b'\x00'), 0)
+ self.assertEqual(jit_wrapper_generator.extract_int_from_bytes(b'\x0a'), 10)
+ self.assertEqual(
+ jit_wrapper_generator.extract_int_from_bytes(b'\xff\xff\xff\xff'),
+ 0xFFFFFFFF,
+ )
+
+ def test_bits_domain_power_of_2(self):
+ # With no domain spec, 8/16/32/64-bit values map to fuzztest::Arbitrary.
+ u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(u32, None),
+ 'fuzztest::Arbitrary()',
+ )
+
+ def test_bits_domain_non_power_of_2(self):
+ # Other widths are limited to [0, 2**bit_count - 1] via InRange.
+ u17 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=17)
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(u17, None),
+ 'fuzztest::InRange(0, 131071)',
+ )
+
+ def test_bits_domain_too_wide(self):
+ # A 128-bit value has no specialized C type, so no FuzzTest domain is
+ # produced and callers fall back to xls::ArbitraryValue.
+ u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
+ self.assertIsNone(jit_wrapper_generator.to_domain(u128, None))
+
+ def test_range_domain(self):
+ # Range bounds are decoded from the ValueProto bits data.
+ u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
+ d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
+ d.range.min.bits.bit_count = 32
+ d.range.min.bits.data = b'\x01'
+ d.range.max.bits.bit_count = 32
+ d.range.max.bits.data = b'\x0a'
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(u32, d),
+ 'fuzztest::InRange(1, 10)',
+ )
+
+ def test_range_domain_wide_bits_fits(self):
+ # A range on a wide type is still usable when its bounds fit in 64 bits.
+ u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128)
+ d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
+ d.range.min.bits.bit_count = 128
+ d.range.min.bits.data = b'\x01'
+ d.range.max.bits.bit_count = 128
+ d.range.max.bits.data = b'\x0a'
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(u128, d),
+ 'fuzztest::InRange(1, 10)',
+ )
+
+ def test_element_of_domain(self):
+ # Each ElementOf candidate becomes an entry of a std::vector initializer.
+ u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
+ d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
+ v1 = d.element_of.values.add()
+ v1.bits.bit_count = 32
+ v1.bits.data = b'\x01'
+ v2 = d.element_of.values.add()
+ v2.bits.bit_count = 32
+ v2.bits.data = b'\x02'
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(u32, d),
+ 'fuzztest::ElementOf(std::vector{1, 2})',
+ )
+
+ def test_tuple_domain(self):
+ # Each tuple element gets its own domain; an `arbitrary` element uses the
+ # element type's unconstrained domain.
+ u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
+ tup = type_pb2.TypeProto(
+ type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u32]
+ )
+ d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
+ d.tuple.elements.add().range.min.bits.bit_count = 32
+ d.tuple.elements[0].range.min.bits.data = b'\x00'
+ d.tuple.elements[0].range.max.bits.bit_count = 32
+ d.tuple.elements[0].range.max.bits.data = b'\x0a'
+ d.tuple.elements.add().arbitrary = True
+
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(tup, d),
+ 'fuzztest::TupleOf(fuzztest::InRange(0, 10),'
+ ' fuzztest::Arbitrary())',
+ )
+
+ def test_nested_tuple_domain(self):
+ # Tuple domains nest to match nested tuple types.
+ u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
+ inner_tup = type_pb2.TypeProto(
+ type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32]
+ )
+ outer_tup = type_pb2.TypeProto(
+ type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, inner_tup]
+ )
+
+ d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
+ d.tuple.elements.add().arbitrary = True
+ inner_d = d.tuple.elements.add().tuple.elements.add()
+ inner_d.range.min.bits.bit_count = 32
+ inner_d.range.min.bits.data = b'\x00'
+ inner_d.range.max.bits.bit_count = 32
+ inner_d.range.max.bits.data = b'\x05'
+
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(outer_tup, d),
+ 'fuzztest::TupleOf(fuzztest::Arbitrary(),'
+ ' fuzztest::TupleOf(fuzztest::InRange(0, 5)))',
+ )
+
+ def test_tuple_with_array_domain(self):
+ # An arbitrary array element becomes a fixed-size VectorOf over the
+ # element type's domain.
+ u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
+ arr = type_pb2.TypeProto(
+ type_enum=type_pb2.TypeProto.ARRAY, array_size=3, array_element=u32
+ )
+ tup = type_pb2.TypeProto(
+ type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, arr]
+ )
+
+ d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
+ d.tuple.elements.add().arbitrary = True
+ d.tuple.elements.add().arbitrary = True
+
+ self.assertEqual(
+ jit_wrapper_generator.to_domain(tup, d),
+ 'fuzztest::TupleOf(fuzztest::Arbitrary(),'
+ ' fuzztest::VectorOf(fuzztest::Arbitrary()).WithSize(3))',
+ )
+
+ def test_unsupported_domain_raises(self):
+ # A range domain cannot be applied to a tuple-typed value.
+ u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32)
+ tup = type_pb2.TypeProto(
+ type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32]
+ )
+ d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain()
+ d.range.min.bits.bit_count = 32
+ d.range.min.bits.data = b'\x00'
+ d.range.max.bits.bit_count = 32
+ d.range.max.bits.data = b'\x0a'
+
+ with self.assertRaises(Exception):
+ jit_wrapper_generator.to_domain(tup, d)
+
+
if __name__ == '__main__':
absltest.main()