From 9da87443a5e2835bb8a0e5527105a8438e713360 Mon Sep 17 00:00:00 2001 From: David Plass Date: Fri, 24 Apr 2026 12:12:21 -0700 Subject: [PATCH] [DSLX Fuzz Testing] Pass parameters (of type xls::Value) by const reference to the test function, in generated fuzz tests. PiperOrigin-RevId: 905158594 --- xls/build_rules/tests/fuzz_test_example.x | 1 + xls/dev_tools/BUILD | 3 + xls/dev_tools/extract_interface.cc | 32 ++ xls/dev_tools/extract_interface_test.cc | 100 +++++ xls/dslx/ir_convert/BUILD | 5 + xls/dslx/ir_convert/function_converter.cc | 98 +++++ xls/dslx/ir_convert/function_converter.h | 10 + .../ir_convert/function_converter_test.cc | 349 ++++++++++++++++++ xls/dslx/ir_convert/ir_converter.cc | 6 + xls/ir/BUILD | 7 +- xls/ir/function.cc | 30 +- xls/ir/function_test.cc | 21 ++ xls/ir/ir_parser.cc | 71 +++- xls/ir/ir_parser.h | 6 +- xls/ir/ir_parser_test.cc | 92 +++++ xls/ir/ir_scanner.cc | 9 + xls/ir/ir_scanner.h | 1 + xls/ir/xls_ir_interface.proto | 30 ++ xls/jit/BUILD | 1 + xls/jit/fuzztest_cc.tmpl | 26 +- xls/jit/jit_wrapper_generator.py | 182 ++++++++- xls/jit/jit_wrapper_generator_test.py | 142 +++++++ 22 files changed, 1195 insertions(+), 27 deletions(-) diff --git a/xls/build_rules/tests/fuzz_test_example.x b/xls/build_rules/tests/fuzz_test_example.x index 9aa4522439..d0ebe8b61d 100644 --- a/xls/build_rules/tests/fuzz_test_example.x +++ b/xls/build_rules/tests/fuzz_test_example.x @@ -12,4 +12,5 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+#[fuzz_test(domains=`u32:0..100, u32:0..100`)] fn my_fuzz_property(x: u32, y: u32) -> bool { x + y == y + x } diff --git a/xls/dev_tools/BUILD b/xls/dev_tools/BUILD index ef5da7f5a4..e7ee4f77ee 100644 --- a/xls/dev_tools/BUILD +++ b/xls/dev_tools/BUILD @@ -119,6 +119,7 @@ cc_library( srcs = ["extract_interface.cc"], hdrs = ["extract_interface.h"], deps = [ + "//xls/common:attribute_data", "//xls/ir", "//xls/ir:channel", "//xls/ir:register", @@ -127,6 +128,7 @@ cc_library( "//xls/ir:xls_ir_interface_cc_proto", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", + "@com_google_protobuf//:protobuf", ], ) @@ -154,6 +156,7 @@ cc_test( srcs = ["extract_interface_test.cc"], deps = [ ":extract_interface", + "//xls/common:attribute_data", "//xls/common:proto_test_utils", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", diff --git a/xls/dev_tools/extract_interface.cc b/xls/dev_tools/extract_interface.cc index 6e6597989a..953c8ad792 100644 --- a/xls/dev_tools/extract_interface.cc +++ b/xls/dev_tools/extract_interface.cc @@ -14,8 +14,13 @@ #include "xls/dev_tools/extract_interface.h" +#include +#include + #include "absl/log/check.h" #include "absl/status/statusor.h" +#include "google/protobuf/text_format.h" +#include "xls/common/attribute_data.h" #include "xls/ir/block.h" #include "xls/ir/channel.h" #include "xls/ir/function.h" @@ -61,6 +66,33 @@ PackageInterfaceProto::Function ExtractFunctionInterface(Function* func) { AddNamed(proto.add_parameters(), param); } *proto.mutable_result_type() = func->GetType()->return_type()->ToProto(); + + if (func->HasAttribute(AttributeKind::kFuzzTest)) { + for (const auto& attr : func->attributes()) { + if (attr.kind() == AttributeKind::kFuzzTest) { + for (const auto& arg : attr.args()) { + CHECK(std::holds_alternative( + arg)) + << "kFuzzTest argument must be StringKeyValueArgument"; + const auto& skv = + std::get(arg); + CHECK_EQ(skv.first, "domains") + << "kFuzzTest only supports 'domains' 
argument"; + CHECK(skv.is_backticked) + << "kFuzzTest domains argument must be backticked"; + + std::string proto_str = skv.second; + PackageInterfaceProto::Function temp_func; + CHECK(google::protobuf::TextFormat::ParseFromString(proto_str, &temp_func)) + << "Failed to parse fuzz test domains proto from attribute"; + proto.mutable_parameter_domains()->CopyFrom( + temp_func.parameter_domains()); + } + break; + } + } + } + return proto; } diff --git a/xls/dev_tools/extract_interface_test.cc b/xls/dev_tools/extract_interface_test.cc index 77fd46e835..9808411bb0 100644 --- a/xls/dev_tools/extract_interface_test.cc +++ b/xls/dev_tools/extract_interface_test.cc @@ -15,9 +15,12 @@ #include "xls/dev_tools/extract_interface.h" #include +#include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "xls/common/attribute_data.h" #include "xls/common/proto_test_utils.h" #include "xls/common/status/matchers.h" #include "xls/ir/bits.h" @@ -164,5 +167,102 @@ TEST_F(ExtractInterfaceTest, BasicBlock) { )pb")); } +TEST_F(ExtractInterfaceTest, FuzzTestFunction) { + constexpr std::string_view kIr = R"( +package test + +#[fuzz_test(domains = `parameter_domains { range { min { bits { bit_count: 32 data: "\000" } } max { bits { bit_count: 32 data: "\012" } } } } parameter_domains { arbitrary: true }`)] +fn f(x: bits[32], y: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(kIr)); + + PackageInterfaceProto proto = ExtractPackageInterface(p.get()); + + ASSERT_EQ(proto.functions().size(), 1); + const auto& func_proto = proto.functions(0); + + ASSERT_EQ(func_proto.parameter_domains().size(), 2); + EXPECT_TRUE(func_proto.parameter_domains(1).arbitrary()); + EXPECT_TRUE(func_proto.parameter_domains(0).has_range()); +} + +TEST_F(ExtractInterfaceTest, FuzzTestFunctionNoDomains) { + constexpr std::string_view kIr = R"( +package test + +#[fuzz_test] +fn f(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) 
+} +)"; + XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(kIr)); + + PackageInterfaceProto proto = ExtractPackageInterface(p.get()); + + ASSERT_EQ(proto.functions().size(), 1); + const auto& func_proto = proto.functions(0); + + EXPECT_THAT(func_proto.parameter_domains(), testing::IsEmpty()); +} + +TEST_F(ExtractInterfaceTest, FuzzTestFunctionInvalidArgKeyManual) { + VerifiedPackage p("test_package"); + Function* f; + { + FunctionBuilder fb("f", &p); + fb.Param("x", p.GetBitsType(32)); + XLS_ASSERT_OK_AND_ASSIGN(f, fb.Build()); + } + + std::vector args; + args.push_back(AttributeData::StringKeyValueArgument{ + .first = "invalid", .second = "value", .is_backticked = true}); + f->AddAttribute(AttributeData(AttributeKind::kFuzzTest, std::move(args))); + + EXPECT_DEATH(ExtractPackageInterface(&p), + "kFuzzTest only supports 'domains' argument"); +} + +TEST_F(ExtractInterfaceTest, PackageWithMixOfFuzzAndNonFuzzFunctions) { + constexpr std::string_view kIr = R"( +package test + +#[fuzz_test(domains = `parameter_domains { arbitrary: true }`)] +fn fuzz_me(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} + +fn dont_fuzz_me(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(auto p, ParsePackage(kIr)); + + PackageInterfaceProto proto = ExtractPackageInterface(p.get()); + + ASSERT_EQ(proto.functions().size(), 2); + + const PackageInterfaceProto::Function* fuzz_proto = nullptr; + const PackageInterfaceProto::Function* non_fuzz_proto = nullptr; + + for (const auto& f : proto.functions()) { + if (f.base().name() == "fuzz_me") { + fuzz_proto = &f; + } else if (f.base().name() == "dont_fuzz_me") { + non_fuzz_proto = &f; + } + } + + ASSERT_NE(fuzz_proto, nullptr); + ASSERT_NE(non_fuzz_proto, nullptr); + + EXPECT_EQ(fuzz_proto->parameter_domains().size(), 1); + EXPECT_TRUE(fuzz_proto->parameter_domains(0).arbitrary()); + + EXPECT_TRUE(non_fuzz_proto->parameter_domains().empty()); +} + } // namespace } // namespace xls diff 
--git a/xls/dslx/ir_convert/BUILD b/xls/dslx/ir_convert/BUILD index 3b154dca43..5377f864d5 100644 --- a/xls/dslx/ir_convert/BUILD +++ b/xls/dslx/ir_convert/BUILD @@ -335,6 +335,7 @@ cc_library( ":ir_conversion_utils", ":proc_config_ir_converter", ":proc_scoped_channel_scope", + "//xls/common:attribute_data", "//xls/common:visitor", "//xls/common/status:ret_check", "//xls/common/status:status_macros", @@ -383,6 +384,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", ], ) @@ -395,6 +397,7 @@ cc_test( ":convert_options", ":function_converter", ":test_utils", + "//xls/common:attribute_data", "//xls/common:proto_test_utils", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", @@ -404,6 +407,8 @@ cc_test( "//xls/dslx/frontend:ast", "//xls/ir", "//xls/ir:xls_ir_interface_cc_proto", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", "@googletest//:gtest", ], ) diff --git a/xls/dslx/ir_convert/function_converter.cc b/xls/dslx/ir_convert/function_converter.cc index 37828697f4..3eaee4eec1 100644 --- a/xls/dslx/ir_convert/function_converter.cc +++ b/xls/dslx/ir_convert/function_converter.cc @@ -40,6 +40,8 @@ #include "absl/strings/substitute.h" #include "absl/types/span.h" #include "absl/types/variant.h" +#include "google/protobuf/text_format.h" +#include "xls/common/attribute_data.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/common/visitor.h" @@ -3580,6 +3582,12 @@ absl::Status FunctionConverter::HandleFunction( VLOG(5) << "Built function: " << ir_fn->name(); XLS_RETURN_IF_ERROR(VerifyFunction(ir_fn)); + XLS_ASSIGN_OR_RETURN(std::optional fuzz_test_attr, + LowerFuzzTestDomains(node)); + if (fuzz_test_attr.has_value()) { + ir_fn->AddAttribute(std::move(*fuzz_test_attr)); + } + // If it's a public fallible function, or it's the entry function for the // package, we 
make a wrapper so that the external world (e.g. JIT, verilog // module) doesn't need to take implicit token arguments. @@ -3605,6 +3613,96 @@ absl::Status FunctionConverter::HandleFunction( return absl::OkStatus(); } +absl::Status FunctionConverter::LowerDomainExpr( + Expr* expr, PackageInterfaceProto::FuzzTestDomain* proto) { + if (expr->kind() == AstNodeKind::kXlsTuple && + dynamic_cast(expr)->empty()) { + proto->set_arbitrary(true); + return absl::OkStatus(); + } + if (expr->kind() == AstNodeKind::kRange) { + Range* range_node = static_cast(expr); + + XLS_ASSIGN_OR_RETURN(InterpValue min_val, + current_type_info_->GetConstExpr(range_node->start())); + XLS_ASSIGN_OR_RETURN(InterpValue max_val, + current_type_info_->GetConstExpr(range_node->end())); + + XLS_ASSIGN_OR_RETURN(Value ir_min, InterpValueToValue(min_val)); + XLS_ASSIGN_OR_RETURN(Value ir_max, InterpValueToValue(max_val)); + + XLS_ASSIGN_OR_RETURN(ValueProto min_proto, ir_min.AsProto()); + XLS_ASSIGN_OR_RETURN(ValueProto max_proto, ir_max.AsProto()); + + auto* range_proto = proto->mutable_range(); + *range_proto->mutable_min() = std::move(min_proto); + *range_proto->mutable_max() = std::move(max_proto); + return absl::OkStatus(); + } + if (expr->kind() == AstNodeKind::kArray) { + Array* array_node = static_cast(expr); + + auto* element_of_proto = proto->mutable_element_of(); + for (Expr* member : array_node->members()) { + XLS_ASSIGN_OR_RETURN(InterpValue val, + current_type_info_->GetConstExpr(member)); + XLS_ASSIGN_OR_RETURN(Value ir_val, InterpValueToValue(val)); + XLS_ASSIGN_OR_RETURN(ValueProto val_proto, ir_val.AsProto()); + *element_of_proto->add_values() = std::move(val_proto); + } + return absl::OkStatus(); + } + if (expr->kind() == AstNodeKind::kXlsTuple) { + XlsTuple* tuple_node = static_cast(expr); + + auto* tuple_proto = proto->mutable_tuple(); + for (Expr* member : tuple_node->members()) { + XLS_RETURN_IF_ERROR(LowerDomainExpr(member, tuple_proto->add_elements())); + } + return 
absl::OkStatus(); + } + return absl::UnimplementedError( + absl::StrCat("Unsupported fuzztest domain type: ", expr->ToString())); +} + +absl::StatusOr> +FunctionConverter::LowerFuzzTestDomains(Function* node) { + if (node->parent() != nullptr && + node->parent()->kind() == AstNodeKind::kFuzzTestFunction) { + FuzzTestFunction* ft = static_cast(node->parent()); + + if (ft->domains().has_value()) { + XlsTuple* domains_tuple = *ft->domains(); + // We use a dummy Function proto here solely to get the + // `parameter_domains` field name wrapper in the serialized text proto. + // This will allow clients to easily parse the string back into a Function + // proto and recover the domains therein. + PackageInterfaceProto::Function temp_func; + + for (Expr* domain_expr : domains_tuple->members()) { + PackageInterfaceProto::FuzzTestDomain* domain_proto = + temp_func.add_parameter_domains(); + + XLS_RETURN_IF_ERROR(LowerDomainExpr(domain_expr, domain_proto)); + } + + std::string proto_str; + google::protobuf::TextFormat::Printer printer; + printer.SetSingleLineMode(true); + XLS_RET_CHECK(printer.PrintToString(temp_func, &proto_str)); + + std::vector args; + args.push_back( + AttributeData::StringKeyValueArgument{.first = "domains", + .second = std::move(proto_str), + .is_backticked = true}); + + return AttributeData(AttributeKind::kFuzzTest, std::move(args)); + } + } + return std::nullopt; +} + absl::Status FunctionConverter::HandleSpawn(const Spawn* node) { VLOG(5) << "HandleSpawn: " << node->ToString(); if (!options_.lower_to_proc_scoped_channels) { diff --git a/xls/dslx/ir_convert/function_converter.h b/xls/dslx/ir_convert/function_converter.h index a697860dc7..9e10b89ac3 100644 --- a/xls/dslx/ir_convert/function_converter.h +++ b/xls/dslx/ir_convert/function_converter.h @@ -31,6 +31,7 @@ #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/types/span.h" +#include "xls/common/attribute_data.h" #include "xls/dslx/channel_direction.h" #include 
"xls/dslx/frontend/ast.h" #include "xls/dslx/frontend/pos.h" @@ -693,6 +694,15 @@ class FunctionConverter { absl::flat_hash_map state_write_called_by_state_name_; std::vector> proc_def_instances_; + + // If the function has a kFuzzTest attribute, this method will convert the + // fuzz test domains to proto and insert it into the AttributeData for + // storage in the IR. + absl::StatusOr> LowerFuzzTestDomains( + Function* node); + + absl::Status LowerDomainExpr(Expr* expr, + PackageInterfaceProto::FuzzTestDomain* proto); }; } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/function_converter_test.cc b/xls/dslx/ir_convert/function_converter_test.cc index 72a556102b..7e39a6cc77 100644 --- a/xls/dslx/ir_convert/function_converter_test.cc +++ b/xls/dslx/ir_convert/function_converter_test.cc @@ -19,6 +19,8 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" +#include "absl/types/span.h" +#include "xls/common/attribute_data.h" #include "xls/common/proto_test_utils.h" #include "xls/common/status/matchers.h" #include "xls/dslx/create_import_data.h" @@ -579,5 +581,352 @@ TEST(FunctionConverterTest, ConvertsFunctionWithUpdate2DBuiltinEmptyTuple) { )pb")); } +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithBasicDomains) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_test(domains = `u32:0..10, ()`)] +fn f(x: u32, y: u32) -> u32 { x + y } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + XLS_ASSERT_OK( + 
converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + + absl::Span attributes = ir_fn->attributes(); + ASSERT_EQ(attributes.size(), 1); + const AttributeData& attr = attributes[0]; + EXPECT_EQ(attr.kind(), AttributeKind::kFuzzTest); + ASSERT_EQ(attr.args().size(), 1); + const AttributeData::Argument& arg = attr.args()[0]; + EXPECT_THAT( + arg, + testing::VariantWith( + testing::AllOf( + testing::Field(&AttributeData::StringKeyValueArgument::first, + "domains"), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("parameter_domains")), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("arbitrary: true")), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("range")), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("bit_count: 32"))))); +} + +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithDifferentBitWidths) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_test(domains = `u8:0..5, u64:0..100`)] +fn f(x: u8, y: u64) -> u64 { (x as u64) + y } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + XLS_ASSERT_OK( + 
converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = ir_fn->attributes(); + ASSERT_EQ(attributes.size(), 1); + const AttributeData::Argument& arg = attributes[0].args()[0]; + EXPECT_THAT( + arg, + testing::VariantWith( + testing::AllOf( + testing::Field(&AttributeData::StringKeyValueArgument::first, + "domains"), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("bit_count: 8")), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("bit_count: 64"))))); +} + +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithRangeEdgeCases) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_test(domains = `u32:10..10, u32:0..0xFFFFFFFF`)] +fn f(x: u32, y: u32) -> u32 { x + y } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + XLS_ASSERT_OK( + converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = ir_fn->attributes(); + 
ASSERT_EQ(attributes.size(), 1); + const AttributeData::Argument& arg = attributes[0].args()[0]; + EXPECT_THAT( + arg, + testing::VariantWith( + testing::AllOf( + testing::Field(&AttributeData::StringKeyValueArgument::first, + "domains"), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("bit_count: 32"))))); +} + +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithMultipleSameDomains) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_test(domains = `u32:0..10, u32:20..30`)] +fn f(x: u32, y: u32) -> u32 { x + y } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + XLS_ASSERT_OK( + converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = ir_fn->attributes(); + ASSERT_EQ(attributes.size(), 1); + const AttributeData::Argument& arg = attributes[0].args()[0]; + EXPECT_THAT( + arg, + testing::VariantWith( + testing::AllOf( + testing::Field(&AttributeData::StringKeyValueArgument::first, + "domains"), + testing::Field(&AttributeData::StringKeyValueArgument::second, + testing::HasSubstr("parameter_domains"))))); +} + +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithTupleDomain) { + ImportData import_data = 
CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_test(domains = `(u32:0..10, u32:0..20)`)] +fn f(x: (u32, u32)) -> u32 { x.0 } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{.conversion_info = &package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + + XLS_ASSERT_OK( + converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = ir_fn->attributes(); + const auto& skv = + std::get(attributes[0].args()[0]); + + EXPECT_THAT(skv.second, testing::HasSubstr("parameter_domains")); + EXPECT_THAT(skv.second, testing::HasSubstr("tuple")); + EXPECT_THAT(skv.second, testing::HasSubstr("range")); +} + +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithElementOfDomain) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_test(domains = `[u32:1, u32:2, u32:3]`)] +fn f(x: u32) -> u32 { x } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter 
converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + + XLS_ASSERT_OK( + converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = ir_fn->attributes(); + const auto& skv = + std::get(attributes[0].args()[0]); + + EXPECT_THAT(skv.second, testing::HasSubstr("parameter_domains")); + EXPECT_THAT(skv.second, testing::HasSubstr("element_of")); +} + +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithEnumElementOfDomain) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +enum MyEnum : u32 { + A = 1, + B = 2, +} +#[fuzz_test(domains = `[MyEnum::A, MyEnum::B]`)] +fn f(x: MyEnum) -> MyEnum { x } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + + XLS_ASSERT_OK( + converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = ir_fn->attributes(); + const auto& skv = + std::get(attributes[0].args()[0]); + + EXPECT_THAT(skv.second, 
testing::HasSubstr("parameter_domains")); + EXPECT_THAT(skv.second, testing::HasSubstr("element_of")); +} + +TEST(FunctionConverterTest, ConvertsFuzzTestFunctionWithNestedTupleDomain) { + ImportData import_data = CreateImportDataForTest(); + XLS_ASSERT_OK_AND_ASSIGN( + TypecheckedModule tm, + ParseAndTypecheck(R"( +#[fuzz_test(domains = `u32:0..10, ((), u32:0..5)`)] +fn f(x: u32, y: ((), u32)) -> u32 { x } +)", + "test_module.x", "test_module", &import_data)); + + XLS_ASSERT_OK_AND_ASSIGN(FuzzTestFunction * ft, + tm.module->GetMemberOrError("f")); + ASSERT_NE(ft, nullptr); + Function* f = &ft->fn(); + + const ConvertOptions convert_options; + PackageConversionData package = MakeConversionData("test_module_package"); + PackageData package_data{&package}; + FunctionConverter converter(package_data, tm.module, &import_data, + convert_options, /*proc_data=*/nullptr, + /*channel_scope=*/nullptr, + /*is_top=*/true); + + XLS_ASSERT_OK( + converter.HandleFunction(f, tm.type_info, /*parametric_env=*/nullptr)); + + ASSERT_FALSE(package_data.conversion_info->package->functions().empty()); + auto* ir_fn = + package_data.conversion_info->package->functions().front().get(); + + EXPECT_TRUE(ir_fn->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = ir_fn->attributes(); + const auto& skv = + std::get(attributes[0].args()[0]); + + EXPECT_THAT(skv.second, testing::HasSubstr("parameter_domains")); + EXPECT_THAT(skv.second, testing::HasSubstr("tuple")); + EXPECT_THAT(skv.second, testing::HasSubstr("arbitrary: true")); +} + } // namespace } // namespace xls::dslx diff --git a/xls/dslx/ir_convert/ir_converter.cc b/xls/dslx/ir_convert/ir_converter.cc index 130198636e..3ef73168c1 100644 --- a/xls/dslx/ir_convert/ir_converter.cc +++ b/xls/dslx/ir_convert/ir_converter.cc @@ -522,6 +522,12 @@ absl::Status ConvertOneFunctionIntoPackage(Module* module, return ConvertOneFunctionIntoPackageInternal(&(*test_fn)->fn(), import_data, options, conv); } + absl::StatusOr fuzz_test_fn = + 
module->GetMemberOrError(entry_function_name); + if (fuzz_test_fn.ok()) { + return ConvertOneFunctionIntoPackageInternal(&(*fuzz_test_fn)->fn(), + import_data, options, conv); + } std::optional fn_or = module->GetFunction(entry_function_name); if (fn_or.has_value()) { diff --git a/xls/ir/BUILD b/xls/ir/BUILD index 3acf6fea70..2f9ac7b4da 100644 --- a/xls/ir/BUILD +++ b/xls/ir/BUILD @@ -895,6 +895,7 @@ cc_library( ":verifier", "//xls/codegen:module_signature", "//xls/codegen:module_signature_cc_proto", + "//xls/common:attribute_data", "//xls/common:visitor", "//xls/common/status:ret_check", "//xls/common/status:status_macros", @@ -928,6 +929,7 @@ cc_test( ":state_element", ":type", ":value", + "//xls/common:attribute_data", "//xls/common:source_location", "//xls/common:xls_gunit_main", "//xls/common/status:matchers", @@ -1570,7 +1572,10 @@ py_proto_library( proto_library( name = "xls_ir_interface_proto", srcs = ["xls_ir_interface.proto"], - deps = [":xls_type_proto"], + deps = [ + ":xls_type_proto", + ":xls_value_proto", + ], ) cc_proto_library( diff --git a/xls/ir/function.cc b/xls/ir/function.cc index c99e0e4389..45157be200 100644 --- a/xls/ir/function.cc +++ b/xls/ir/function.cc @@ -32,9 +32,11 @@ #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "absl/types/span.h" +#include "absl/types/variant.h" #include "xls/common/attribute_data.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" +#include "xls/common/visitor.h" #include "xls/ir/change_listener.h" #include "xls/ir/function_base.h" #include "xls/ir/ir_annotator.h" @@ -348,8 +350,32 @@ std::vector Function::AttributeIrStrings() const { attribute_strings.push_back("non_synth"); } for (const auto& attr : attributes_) { - // TODO - davidplass: Properly serialize AttributeData arguments. 
- attribute_strings.push_back(AttributeKindToString(attr.kind())); + std::string attr_str = AttributeKindToString(attr.kind()); + if (!attr.args().empty()) { + std::vector arg_strings; + for (const auto& arg : attr.args()) { + std::string arg_str = absl::visit( + Visitor{ + [](const std::string& s) { return s; }, + [](const AttributeData::StringLiteralArgument& s) { + return absl::StrCat("\"", s.text, "\""); + }, + [](const AttributeData::StringKeyValueArgument& s) { + if (s.is_backticked) { + return absl::StrCat(s.first, "=`", s.second, "`"); + } + return absl::StrCat(s.first, "=\"", s.second, "\""); + }, + [](const AttributeData::IntKeyValueArgument& s) { + return absl::StrCat(s.first, "=", s.second); + }, + }, + arg); + arg_strings.push_back(arg_str); + } + absl::StrAppend(&attr_str, "(", absl::StrJoin(arg_strings, ", "), ")"); + } + attribute_strings.push_back(attr_str); } return attribute_strings; } diff --git a/xls/ir/function_test.cc b/xls/ir/function_test.cc index 1f11e09183..1615c2ad8b 100644 --- a/xls/ir/function_test.cc +++ b/xls/ir/function_test.cc @@ -642,5 +642,26 @@ TEST_F(FunctionTest, AttributeDataTest) { EXPECT_THAT(f->DumpIr(), HasSubstr("#[fuzz_test]")); } +TEST_F(FunctionTest, AttributeSerializationWithArgumentsTest) { + auto p = CreatePackage(); + FunctionBuilder b("f", p.get()); + b.Param("x", p->GetBitsType(32)); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, b.BuildWithReturnValue(b.Tuple({}))); + + std::vector args; + args.push_back("ident"); + args.push_back(AttributeData::StringLiteralArgument{.text = "literal"}); + args.push_back(AttributeData::StringKeyValueArgument{ + .first = "key", .second = "backticked", .is_backticked = true}); + args.push_back(AttributeData::IntKeyValueArgument{"int_key", 42}); + + f->AddAttribute(AttributeData(AttributeKind::kFuzzTest, args)); + + EXPECT_THAT( + f->DumpIr(), + HasSubstr( + "#[fuzz_test(ident, \"literal\", key=`backticked`, int_key=42)]")); +} + } // namespace } // namespace xls diff --git 
a/xls/ir/ir_parser.cc b/xls/ir/ir_parser.cc index 72b7dbcaf1..7566f1436e 100644 --- a/xls/ir/ir_parser.cc +++ b/xls/ir/ir_parser.cc @@ -44,6 +44,7 @@ #include "google/protobuf/text_format.h" #include "xls/codegen/module_signature.h" #include "xls/codegen/module_signature.pb.h" +#include "xls/common/attribute_data.h" #include "xls/common/status/ret_check.h" #include "xls/common/status/status_macros.h" #include "xls/common/visitor.h" @@ -2386,21 +2387,39 @@ absl::StatusOr Parser::ParseFunctionInternal( } for (const IrAttribute& attribute : outer_attributes) { - if (std::holds_alternative(attribute.payload)) { - result->SetInitiationInterval( - std::get(attribute.payload).value); - } else if (std::holds_alternative(attribute.payload)) { - // Dummy parse to make sure it is a valid template. - const ForeignFunctionData& ffi = - std::get(attribute.payload); - XLS_RETURN_IF_ERROR(CodeTemplate::Create(ffi.code_template()).status()); - result->SetForeignFunctionData(ffi); - } else if (std::holds_alternative(attribute.payload)) { - result->set_non_synth(true); - } else { - return absl::InvalidArgumentError(absl::StrFormat( - "Invalid attribute for function: %s", attribute.name)); - } + XLS_RETURN_IF_ERROR(absl::visit( + Visitor{[&](const InitiationInterval& ii) -> absl::Status { + result->SetInitiationInterval(ii.value); + return absl::OkStatus(); + }, + [&](const ForeignFunctionData& ffi) -> absl::Status { + // Dummy parse to make sure it is a valid template.
+ XLS_RETURN_IF_ERROR( + CodeTemplate::Create(ffi.code_template()).status()); + result->SetForeignFunctionData(ffi); + return absl::OkStatus(); + }, + [&](const NonSynthMarker&) -> absl::Status { + result->set_non_synth(true); + return absl::OkStatus(); + }, + [&](const FuzzTestAttribute& fuzz_test_attr) -> absl::Status { // Re-encode the parsed domains as a backticked "domains" key/value argument. + std::vector args; + if (fuzz_test_attr.domains_proto_str.has_value()) { + args.push_back(AttributeData::StringKeyValueArgument{ + .first = "domains", + .second = *fuzz_test_attr.domains_proto_str, + .is_backticked = true}); + } + result->AddAttribute( + AttributeData(AttributeKind::kFuzzTest, std::move(args))); + return absl::OkStatus(); + }, + [&](const auto&) -> absl::Status { // Any other payload kind is rejected on functions. + return absl::InvalidArgumentError(absl::StrFormat( + "Invalid attribute for function: %s", attribute.name)); + }}, + attribute.payload)); } return result; @@ -3020,6 +3039,28 @@ absl::StatusOr Parser::ParseAttribute(Package* package) { scanner_.DropTokenOrError(LexicalTokenType::kParenClose)); return IrAttribute{.name = attribute_name.value(), .payload = ffi}; } + if (attribute_name.value() == "fuzz_test") { + std::optional domains_proto_str; + + if (scanner_.TryDropToken(LexicalTokenType::kParenOpen)) { // Arguments are optional: a bare #[fuzz_test] carries no domains. + absl::flat_hash_map> handlers; + handlers["domains"] = [&]() -> absl::Status { + XLS_ASSIGN_OR_RETURN( + Token token, + scanner_.PopTokenOrError(LexicalTokenType::kBacktickedString, + "backticked string")); + domains_proto_str = token.value(); + return absl::OkStatus(); + }; + XLS_RETURN_IF_ERROR(ParseKeywordArguments(handlers)); + XLS_RETURN_IF_ERROR( + scanner_.DropTokenOrError(LexicalTokenType::kParenClose)); + } + + return IrAttribute{ + .name = attribute_name.value(), + .payload = FuzzTestAttribute{.domains_proto_str = domains_proto_str}}; + } if (attribute_name.value() == "channel_ports") { std::optional channel_name; std::optional type; diff --git a/xls/ir/ir_parser.h b/xls/ir/ir_parser.h index 4e7ab138bf..2cd21afeba 100644 --- a/xls/ir/ir_parser.h +++
b/xls/ir/ir_parser.h @@ -74,10 +74,14 @@ struct ResetAttribute { struct NonSynthMarker : public std::monostate {}; +struct FuzzTestAttribute { // Parsed #[fuzz_test(...)]; domains_proto_str holds the backticked domains= value, if given. + std::optional domains_proto_str; +}; + using IrAttributePayload = std::variant; + NonSynthMarker, FuzzTestAttribute>; struct IrAttribute { std::string name; IrAttributePayload payload; diff --git a/xls/ir/ir_parser_test.cc b/xls/ir/ir_parser_test.cc index 9049b833df..5a1e7ec2c8 100644 --- a/xls/ir/ir_parser_test.cc +++ b/xls/ir/ir_parser_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -28,6 +29,7 @@ #include "absl/strings/str_cat.h" #include "absl/strings/substitute.h" #include "absl/types/span.h" +#include "xls/common/attribute_data.h" #include "xls/common/source_location.h" #include "xls/common/status/matchers.h" #include "xls/ir/bits.h" @@ -1206,4 +1208,94 @@ scheduled_block b() { EXPECT_EQ(block->stages().size(), 1); } +TEST(IrParserTest, ParseAttributeWithArguments) { + std::string program = R"( +package test + +#[fuzz_test(domains = `u32:0..100, ()`)] +fn main(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} +)"; + XLS_ASSERT_OK_AND_ASSIGN(auto package, Parser::ParsePackage(program)); + EXPECT_EQ(package->functions().size(), 1); + Function* func = package->functions().front().get(); + EXPECT_TRUE(func->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = func->attributes(); + ASSERT_EQ(attributes.size(), 1); + const AttributeData& attr = attributes[0]; + EXPECT_EQ(attr.kind(), AttributeKind::kFuzzTest); + ASSERT_EQ(attr.args().size(), 1); + const AttributeData::Argument& arg = attr.args()[0]; + ASSERT_TRUE( + std::holds_alternative(arg)); + const auto& skv = std::get(arg); + EXPECT_EQ(skv.first, "domains"); + EXPECT_EQ(skv.second, "u32:0..100, ()"); + EXPECT_TRUE(skv.is_backticked); +} + +TEST(IrParserTest, ParseAttributeWithBrokenBacktick) { + std::string program = R"( +package test + +#[fuzz_test(domains = `u32:0..100,
())] +fn main(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} +)"; + EXPECT_THAT(Parser::ParsePackage(program).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Unterminated quoted string"))); +} + +TEST(IrParserTest, ParseAttributeEmptyFuzzTest) { + std::string program = R"ir( +package test + +#[fuzz_test] +fn main(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} +)ir"; + XLS_ASSERT_OK_AND_ASSIGN(auto package, Parser::ParsePackage(program)); + EXPECT_EQ(package->functions().size(), 1); + Function* func = package->functions().front().get(); + EXPECT_TRUE(func->HasAttribute(AttributeKind::kFuzzTest)); + absl::Span attributes = func->attributes(); + ASSERT_EQ(attributes.size(), 1); + const AttributeData& attr = attributes[0]; + EXPECT_EQ(attr.kind(), AttributeKind::kFuzzTest); + EXPECT_TRUE(attr.args().empty()); +} + +TEST(IrParserTest, ParseAttributeExtraArgument) { + std::string program = R"ir( +package test + +#[fuzz_test(domains = `u32:0..100, ()`, extra = 42)] +fn main(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} +)ir"; + EXPECT_THAT(Parser::ParsePackage(program).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Invalid keyword argument"))); +} + +TEST(IrParserTest, ParseAttributeInvalidDomainType) { + std::string program = R"ir( +package test + +#[fuzz_test(domains = "u32:0..100, ()")] +fn main(x: bits[32]) -> bits[32] { + ret x: bits[32] = param(name=x) +} +)ir"; + EXPECT_THAT( + Parser::ParsePackage(program).status(), + StatusIs(absl::StatusCode::kInvalidArgument, + HasSubstr("Expected token of type \"backticked string\""))); +} + } // namespace xls diff --git a/xls/ir/ir_scanner.cc b/xls/ir/ir_scanner.cc index 142f1c17c1..48055705bb 100644 --- a/xls/ir/ir_scanner.cc +++ b/xls/ir/ir_scanner.cc @@ -74,6 +74,8 @@ std::string LexicalTokenTypeToString(LexicalTokenType token_type) { return "("; case LexicalTokenType::kQuotedString: return "quoted string"; + case
LexicalTokenType::kBacktickedString: + return "backticked string"; case LexicalTokenType::kRightArrow: return "->"; case LexicalTokenType::kHash: @@ -304,6 +306,13 @@ class Tokenizer { start_lineno, start_colno)); continue; } + XLS_ASSIGN_OR_RETURN(content, + MatchQuotedString("`", /*allow_multiline=*/false)); // Backticked strings (e.g. fuzz_test domains) must fit on one line. + if (content.has_value()) { + tokens.push_back(Token(LexicalTokenType::kBacktickedString, + content.value(), start_lineno, start_colno)); + continue; + } // Handle single-character tokens. LexicalTokenType token_type; diff --git a/xls/ir/ir_scanner.h b/xls/ir/ir_scanner.h index 12ca341a1b..81b13e740f 100644 --- a/xls/ir/ir_scanner.h +++ b/xls/ir/ir_scanner.h @@ -53,6 +53,7 @@ enum class LexicalTokenType { kRightArrow, kHash, kBang, + kBacktickedString, }; std::string LexicalTokenTypeToString(LexicalTokenType token_type); diff --git a/xls/ir/xls_ir_interface.proto b/xls/ir/xls_ir_interface.proto index 34e734a476..91e5be5d5a 100644 --- a/xls/ir/xls_ir_interface.proto +++ b/xls/ir/xls_ir_interface.proto @@ -17,6 +17,7 @@ syntax = "proto3"; package xls; import "xls/ir/xls_type.proto"; +import "xls/ir/xls_value.proto"; message PackageInterfaceProto { // A generic thing with a name and a type. @@ -59,6 +60,10 @@ message PackageInterfaceProto { optional TypeProto result_type = 3; // If present the corresponding sv type for the result of this function. optional string sv_result_type = 4; + + // The structured fuzzing domains for this function. There may be + // either zero domains, or exactly one domain per parameter.
+ repeated FuzzTestDomain parameter_domains = 5; } message Proc { @@ -87,6 +92,31 @@ message PackageInterfaceProto { repeated NamedValue output_ports = 4; } + message FuzzTestDomain { + message Range { + optional ValueProto min = 1; + optional ValueProto max = 2; + } + message ElementOf { + repeated ValueProto values = 1; + } + message Tuple { + repeated FuzzTestDomain elements = 1; + } + message Array { + optional FuzzTestDomain element_domain = 1; + optional int64 size = 2; + } + + oneof domain_kind { + bool arbitrary = 1; + Range range = 2; + ElementOf element_of = 3; + Tuple tuple = 4; + Array array = 5; + } + } + // Name of the overall package. optional string name = 1; diff --git a/xls/jit/BUILD b/xls/jit/BUILD index 8289daee73..7efe702dda 100644 --- a/xls/jit/BUILD +++ b/xls/jit/BUILD @@ -411,6 +411,7 @@ pytype_strict_contrib_test( deps = [ ":jit_wrapper_generator", "//xls/common:runfiles", + "//xls/ir:xls_ir_interface_py_pb2", "//xls/ir:xls_type_py_pb2", "@abseil-py//absl/testing:absltest", "@xls_pip_deps//jinja2", diff --git a/xls/jit/fuzztest_cc.tmpl b/xls/jit/fuzztest_cc.tmpl index 14061bf52e..546af21a8c 100644 --- a/xls/jit/fuzztest_cc.tmpl +++ b/xls/jit/fuzztest_cc.tmpl @@ -9,14 +9,29 @@ namespace {{ fuzztest.namespace }} { namespace { -void {{fuzztest.property_function_name}}({{ fuzztest.params | map("property_param") | join (", ") }}) { +void {{fuzztest.property_function_name}}({% for param in fuzztest.params %}{% if param.cpp_type == "xls::Value" %}const xls::Value&{% else %}{{param.cpp_type}}{% endif %} {{param.name}}{% if not loop.last %}, {% endif %}{% endfor %}) { {# xls::Value params are taken by const reference; native (specialized) params stay by value. #} XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<{{fuzztest.lib_class_name}}> f, {{fuzztest.lib_class_name}}::Create()); + {% if fuzztest.can_be_specialized %} + // Use specialized overload.
+ XLS_ASSERT_OK_AND_ASSIGN(auto res_native, f->Run( + {% for param in fuzztest.params %}{{param.name}}{% if not loop.last %}, {% endif %}{% endfor %} + )); + xls::Value result = xls::Value(xls::UBits(res_native, {{fuzztest.result_width}})); + {% else %} + // Fallback: Manual binding for the standard Run(xls::Value...) overload. + {% for param in fuzztest.params %} + {% if param.is_native %}{# NOTE(review): conversion_snippet appears to be set only for BITS params; any other native param would render "None" here. #} + xls::Value {{param.name}}_val = {{param.conversion_snippet}}; + {% endif %} + {% endfor %} XLS_ASSERT_OK_AND_ASSIGN(xls::Value result, f->Run( {% for param in fuzztest.params %} - {{param.name}}{% if not loop.last %}, {% endif %} - {% endfor %})); + {{param.name}}{% if param.is_native %}_val{% endif %}{% if not loop.last %}, {% endif %} + {% endfor %} + )); + {% endif %} {% if fuzztest.return_type %} ASSERT_TRUE(result.IsAllOnes()); @@ -26,7 +41,10 @@ void {{fuzztest.property_function_name}}({{ fuzztest.params | map("property_para FUZZ_TEST({{fuzztest.fuzztest_name}}, {{fuzztest.property_function_name}}) .WithDomains( {% for param in fuzztest.params %} - xls::ArbitraryValue({{fuzztest.lib_class_name}}::GetParamType({{param.index}}).value()){% if not loop.last %}, + {% if param.domain_snippet %} + {{param.domain_snippet}}{% if not loop.last %}, {% endif %} + {% else %} + xls::ArbitraryValue({{fuzztest.lib_class_name}}::GetParamType({{param.index}}).value()){% if not loop.last %}, {% endif %} {% endif %} {% endfor %} ); diff --git a/xls/jit/jit_wrapper_generator.py b/xls/jit/jit_wrapper_generator.py index 49759cc165..c138a15354 100644 --- a/xls/jit/jit_wrapper_generator.py +++ b/xls/jit/jit_wrapper_generator.py @@ -27,6 +27,16 @@ from xls.jit import aot_entrypoint_pb2 +@dataclasses.dataclass(frozen=True) +class FuzzTestInfo: + """FuzzTest specific information for a value.""" + + domain_snippet: Optional[str] = None + domain_proto: Optional[ + ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain + ] = None + + @dataclasses.dataclass(frozen=True) class XlsNamedValue: """A Named & typed
value for the wrapped function/proc.""" @@ -36,6 +46,7 @@ class XlsNamedValue: unpacked_type: str specialized_type: Optional[str] type_proto: type_pb2.TypeProto + fuzztest_info: Optional[FuzzTestInfo] = None @property def value_arg(self): @@ -123,6 +134,10 @@ class PropertyFunctionParam: name: str # The index of the parameter in the function signature. index: int + domain_snippet: Optional[str] = None  # C++ FuzzTest domain; None means xls::ArbitraryValue. + cpp_type: str = "xls::Value"  # C++ type of this parameter in the property function. + is_native: bool = False  # True when cpp_type is a native type, not xls::Value. + conversion_snippet: Optional[str] = None  # C++ expr converting the native value to xls::Value. @dataclasses.dataclass(frozen=True) @@ -137,6 +152,8 @@ class PropertyFunction: # Function params and result. params: Sequence[PropertyFunctionParam] return_type: bool + can_be_specialized: bool = False  # Selects the template's specialized Run() path. + result_width: int = 0  # Bit width of a BITS result; 0 otherwise. def to_packed(t: type_pb2.TypeProto) -> str: @@ -263,6 +280,110 @@ def to_specialized( return None +def extract_int_from_bytes(data: bytes) -> int: + return int.from_bytes(data, byteorder="little") + + +# Returns true if the range domain fits within a 64-bit unsigned integer. +def can_use_uint64_range( + t: type_pb2.TypeProto, + d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain], +) -> bool: + if t.type_enum == type_pb2.TypeProto.BITS and d and d.HasField("range"): + max_val = extract_int_from_bytes(d.range.max.bits.data) + return max_val < (1 << 64) + return False + + +def to_domain( + t: type_pb2.TypeProto, + d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain], +) -> Optional[str]: + """Converts an XLS type and domain spec to a FuzzTest domain string. + + Args: + t: The XLS type proto. + d: The optional FuzzTest domain specification from the package interface. + + Returns: + A string representing the C++ FuzzTest domain (e.g., + "fuzztest::Arbitrary()"), or None if it should fallback to + xls::ArbitraryValue. + + Raises: + app.UsageError: If the domain specification is invalid or unsupported for + the given type.
+ """ + if ( + d is None + or d.HasField("arbitrary") + or ( + not d.HasField("range") + and not d.HasField("element_of") + and not d.HasField("tuple")  # NOTE(review): an "array" domain also takes this branch, so its element_domain/size are ignored. + ) + ): + if t.type_enum == type_pb2.TypeProto.BITS: + c_type = to_specialized(t, int_only=True) + if c_type is None: + return None + if t.bit_count in (8, 16, 32, 64): + return f"fuzztest::Arbitrary<{c_type}>()" + else: + max_val = (1 << t.bit_count) - 1 + return f"fuzztest::InRange<{c_type}>(0, {max_val})" + elif t.type_enum == type_pb2.TypeProto.TUPLE: + elems = [to_domain(e, None) for e in t.tuple_elements] + if any(e is None for e in elems): + return None + return f"fuzztest::TupleOf({', '.join(elems)})" + elif t.type_enum == type_pb2.TypeProto.ARRAY: + elem_domain = to_domain(t.array_element, None) + if elem_domain is None: + return None + return f"fuzztest::VectorOf({elem_domain}).WithSize({t.array_size})" + else: + return None + + if d.HasField("range"): + c_type = to_specialized(t, int_only=True) + if c_type is None: + if can_use_uint64_range(t, d): + c_type = "uint64_t" + else: + raise app.UsageError( + "Range domain only supported for specializable bits types or ranges" + " fitting in 64 bits" + ) + min_val = extract_int_from_bytes(d.range.min.bits.data) + max_val = extract_int_from_bytes(d.range.max.bits.data) + return f"fuzztest::InRange<{c_type}>({min_val}, {max_val})" + + if d.HasField("element_of"): + c_type = to_specialized(t, int_only=True) + if c_type is None: + raise app.UsageError( + "ElementOf domain only supported for specializable bits types in" + " this CL" + ) + vals = [] + for v in d.element_of.values: + vals.append(str(extract_int_from_bytes(v.bits.data))) + return f"fuzztest::ElementOf(std::vector<{c_type}>{{{', '.join(vals)}}})" + + if d.HasField("tuple"): + if t.type_enum != type_pb2.TypeProto.TUPLE: + raise app.UsageError("Tuple domain requires Tuple type") + if len(d.tuple.elements) != len(t.tuple_elements): + raise app.UsageError("Tuple domain and type element count mismatch") +
elems = list(map(to_domain, t.tuple_elements, d.tuple.elements)) + if any(e is None for e in elems): + raise app.UsageError("Tuple element has no supported FuzzTest domain") + return f"fuzztest::TupleOf({', '.join(elems)})" + + raise app.UsageError(f"Unsupported domain: {d}") + + def to_chan( c: ir_interface_pb2.PackageInterfaceProto.Channel, package_name: str ) -> XlsChannel: @@ -278,13 +399,21 @@ def to_chan( def to_param( p: ir_interface_pb2.PackageInterfaceProto.NamedValue, + d: Optional[ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain] = None, ) -> XlsNamedValue: + fuzztest_info = None + if d is not None: + fuzztest_info = FuzzTestInfo( + domain_snippet=to_domain(p.type, d), domain_proto=d + ) + return XlsNamedValue( name=p.name, packed_type=to_packed(p.type), unpacked_type=to_unpacked(p.type), specialized_type=to_specialized(p.type), type_proto=p.type, + fuzztest_info=fuzztest_info, ) @@ -330,7 +459,12 @@ def interpret_function_interface( """ if func_ir.base.name != aot_info.entrypoint[0].xls_function_identifier: raise app.UsageError("Aot info is for a different function.") - params = [to_param(p) for p in func_ir.parameters] + params = [] + for idx, p in enumerate(func_ir.parameters): + d = None + if idx < len(func_ir.parameter_domains):  # Proto allows zero domains or one per parameter. + d = func_ir.parameter_domains[idx] + params.append(to_param(p, d)) result = XlsNamedValue( name="result", packed_type=to_packed(func_ir.result_type), @@ -458,16 +592,52 @@ def wrapped_to_fuzztest( params = [] if wrapped.params: for idx, p in enumerate(wrapped.params): - params.append(PropertyFunctionParam(name=p.name, index=idx)) + is_native = p.specialized_type is not None + cpp_type = p.specialized_type or "xls::Value"  # NOTE(review): with no domain_snippet the template uses xls::ArbitraryValue (an xls::Value domain) even for a native cpp_type; confirm these agree. + conversion_snippet = None + + domain_proto = p.fuzztest_info.domain_proto if p.fuzztest_info else None + domain_snippet = ( + p.fuzztest_info.domain_snippet if p.fuzztest_info else None + ) + + if can_use_uint64_range(p.type_proto, domain_proto): + cpp_type = "uint64_t" + is_native = True + + if is_native and p.type_proto.type_enum == type_pb2.TypeProto.BITS: + conversion_snippet = (
+ f"xls::Value(xls::UBits({p.name}, {p.type_proto.bit_count}))" + ) + + params.append( + PropertyFunctionParam( + name=p.name, + index=idx, + domain_snippet=domain_snippet, + cpp_type=cpp_type, + is_native=is_native, + conversion_snippet=conversion_snippet, + ) + ) + + result_width = 0 + if ( + wrapped.result + and wrapped.result.type_proto.type_enum == type_pb2.TypeProto.BITS + ): + result_width = wrapped.result.type_proto.bit_count + return PropertyFunction( fuzztest_name=wrapped.function_name + "_fuzztest", property_function_name=wrapped.function_name, lib_class_name=lib_class_name, lib_header_path=lib_header_path, - # Everything shares the namespace namespace=wrapped.namespace, params=params, return_type=wrapped.result is not None, + can_be_specialized=wrapped.can_be_specialized, + result_width=result_width, ) @@ -479,7 +649,11 @@ def render_fuzztest( lib_header_path: str, ) -> str: """Renders the fuzztest C++ code.""" - env.filters["property_param"] = lambda p: "xls::Value " + p.name + env.filters["property_param"] = lambda p: (  # "const xls::Value& name" for xls::Value params, "T name" otherwise. + f"const {p.cpp_type}& {p.name}" + if p.cpp_type == "xls::Value" + else f"{p.cpp_type} {p.name}" + ) cc_template = env.from_string(cc_template_content) fuzztest = wrapped_to_fuzztest(wrapped, lib_class_name, lib_header_path) bindings = {"fuzztest": fuzztest, "len": len} diff --git a/xls/jit/jit_wrapper_generator_test.py b/xls/jit/jit_wrapper_generator_test.py index 9a025c5019..baa3e0163f 100644 --- a/xls/jit/jit_wrapper_generator_test.py +++ b/xls/jit/jit_wrapper_generator_test.py @@ -16,6 +16,7 @@ from absl.testing import absltest from xls.common import runfiles +from xls.ir import xls_ir_interface_pb2 as ir_interface_pb2 from xls.ir import xls_type_pb2 as type_pb2 from xls.jit import jit_wrapper_generator @@ -576,5 +577,146 @@ def test_render_fuzztest_tuple_mixed(self): ) +class JitWrapperGeneratorToDomainTest(absltest.TestCase): + + def test_extract_int_from_bytes(self):
+ self.assertEqual(jit_wrapper_generator.extract_int_from_bytes(b'\x00'), 0) + self.assertEqual(jit_wrapper_generator.extract_int_from_bytes(b'\x0a'), 10) + self.assertEqual( + jit_wrapper_generator.extract_int_from_bytes(b'\xff\xff\xff\xff'), + 0xFFFFFFFF, + ) + + def test_bits_domain_power_of_2(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + self.assertEqual( + jit_wrapper_generator.to_domain(u32, None), + 'fuzztest::Arbitrary()', + ) + + def test_bits_domain_non_power_of_2(self): + u17 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=17) + self.assertEqual( + jit_wrapper_generator.to_domain(u17, None), + 'fuzztest::InRange(0, 131071)', + ) + + def test_bits_domain_too_wide(self): + u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128) + self.assertIsNone(jit_wrapper_generator.to_domain(u128, None)) + + def test_range_domain(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.range.min.bits.bit_count = 32 + d.range.min.bits.data = b'\x01' + d.range.max.bits.bit_count = 32 + d.range.max.bits.data = b'\x0a' + self.assertEqual( + jit_wrapper_generator.to_domain(u32, d), + 'fuzztest::InRange(1, 10)', + ) + + def test_range_domain_wide_bits_fits(self):  # 128-bit type whose range max fits in 64 bits, so uint64_t is used. + u128 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=128) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.range.min.bits.bit_count = 128 + d.range.min.bits.data = b'\x01' + d.range.max.bits.bit_count = 128 + d.range.max.bits.data = b'\x0a' + self.assertEqual( + jit_wrapper_generator.to_domain(u128, d), + 'fuzztest::InRange(1, 10)', + ) + + def test_element_of_domain(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + v1 = d.element_of.values.add() + v1.bits.bit_count = 32 + v1.bits.data = b'\x01' + v2 =
d.element_of.values.add() + v2.bits.bit_count = 32 + v2.bits.data = b'\x02' + self.assertEqual( + jit_wrapper_generator.to_domain(u32, d), + 'fuzztest::ElementOf(std::vector{1, 2})', + ) + + def test_tuple_domain(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + tup = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, u32] + ) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.tuple.elements.add().range.min.bits.bit_count = 32 + d.tuple.elements[0].range.min.bits.data = b'\x00' + d.tuple.elements[0].range.max.bits.bit_count = 32 + d.tuple.elements[0].range.max.bits.data = b'\x0a' + d.tuple.elements.add().arbitrary = True + + self.assertEqual( + jit_wrapper_generator.to_domain(tup, d), + 'fuzztest::TupleOf(fuzztest::InRange(0, 10),' + ' fuzztest::Arbitrary())', + ) + + def test_nested_tuple_domain(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + inner_tup = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32] + ) + outer_tup = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, inner_tup] + ) + + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.tuple.elements.add().arbitrary = True + inner_d = d.tuple.elements.add().tuple.elements.add() + inner_d.range.min.bits.bit_count = 32 + inner_d.range.min.bits.data = b'\x00' + inner_d.range.max.bits.bit_count = 32 + inner_d.range.max.bits.data = b'\x05' + + self.assertEqual( + jit_wrapper_generator.to_domain(outer_tup, d), + 'fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::TupleOf(fuzztest::InRange(0, 5)))', + ) + + def test_tuple_with_array_domain(self):  # An arbitrary array element maps to VectorOf(...).WithSize(n). + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + arr = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.ARRAY, array_size=3, array_element=u32 + ) + tup = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32, arr] + ) + + d =
ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.tuple.elements.add().arbitrary = True + d.tuple.elements.add().arbitrary = True + + self.assertEqual( + jit_wrapper_generator.to_domain(tup, d), + 'fuzztest::TupleOf(fuzztest::Arbitrary(),' + ' fuzztest::VectorOf(fuzztest::Arbitrary()).WithSize(3))', + ) + + def test_unsupported_domain_raises(self): + u32 = type_pb2.TypeProto(type_enum=type_pb2.TypeProto.BITS, bit_count=32) + tup = type_pb2.TypeProto( + type_enum=type_pb2.TypeProto.TUPLE, tuple_elements=[u32] + ) + d = ir_interface_pb2.PackageInterfaceProto.FuzzTestDomain() + d.range.min.bits.bit_count = 32 + d.range.min.bits.data = b'\x00' + d.range.max.bits.bit_count = 32 + d.range.max.bits.data = b'\x0a' + + with self.assertRaises(Exception):  # Range domain on a TUPLE type; expected to surface as app.UsageError. + jit_wrapper_generator.to_domain(tup, d) + + if __name__ == '__main__': absltest.main()