No public description PiperOrigin-RevId: 871419368
diff --git a/fuzztest/internal/domains/smart_pointer_of_impl.h b/fuzztest/internal/domains/smart_pointer_of_impl.h index abbcbb5..43ca646 100644 --- a/fuzztest/internal/domains/smart_pointer_of_impl.h +++ b/fuzztest/internal/domains/smart_pointer_of_impl.h
@@ -16,7 +16,6 @@ #define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_SMART_POINTER_OF_IMPL_H_ #include <optional> -#include <type_traits> #include <utility> #include <variant> @@ -32,6 +31,7 @@ #include "./fuzztest/internal/printer.h" #include "./fuzztest/internal/serialization.h" #include "./fuzztest/internal/status.h" +#include "./fuzztest/internal/type_support.h" namespace fuzztest::internal { @@ -47,11 +47,7 @@ using typename SmartPointerOfImpl::DomainBase::corpus_type; using typename SmartPointerOfImpl::DomainBase::value_type; - static_assert( - Requires<T>( - [](auto x) -> std::enable_if_t<std::is_pointer_v<decltype(x.get())>, - decltype(!x, *x)> {}), - "T must be a smart pointer type."); + static_assert(is_smart_pointer_v<T>, "T must be a smart pointer type."); // Since we allow for recursion in this domain, we want to delay the // construction of the inner domain. Otherwise we would have an infinite @@ -124,12 +120,7 @@ return; } if (mode == domain_implementor::PrintMode::kSourceCode) { - absl::string_view maker = - is_unique_ptr_v<value_type> ? "std::make_unique" - : is_shared_ptr_v<value_type> ? "std::make_shared" - : "<MAKE_SMART_POINTER>"; - absl::Format(out, "%s<%s>", maker, - GetTypeName<typename value_type::element_type>()); + absl::Format(out, "%s", GetSmartPtrMaker<T>()); } absl::Format(out, "("); PrintValue(inner, std::get<1>(v), out, mode);
diff --git a/fuzztest/internal/meta.h b/fuzztest/internal/meta.h index af4b3ae..25f5838 100644 --- a/fuzztest/internal/meta.h +++ b/fuzztest/internal/meta.h
@@ -143,12 +143,30 @@ inline constexpr bool is_unique_ptr_v<std::unique_ptr<T>> = true; template <typename T> +inline constexpr bool is_unique_ptr_v<const std::unique_ptr<T>> = true; + +template <typename T> inline constexpr bool is_shared_ptr_v = false; template <typename T> inline constexpr bool is_shared_ptr_v<std::shared_ptr<T>> = true; template <typename T> +inline constexpr bool is_shared_ptr_v<const std::shared_ptr<T>> = true; + +template <typename T> +inline constexpr bool is_smart_pointer_v = Requires<T>( + [](auto&& x) + -> std::enable_if_t< + std::is_pointer_v<decltype(x.get())>, + std::void_t<typename std::decay_t<decltype(x)>::element_type, + decltype(!x, *x)>> {}); + +template <typename T> +inline constexpr bool is_complete_type_v = + Requires<T>([](auto x) -> decltype(sizeof(T)) {}); + +template <typename T> inline constexpr bool is_std_complex_v = false; template <typename T>
diff --git a/fuzztest/internal/type_support.h b/fuzztest/internal/type_support.h index eb55471..0273b7e 100644 --- a/fuzztest/internal/type_support.h +++ b/fuzztest/internal/type_support.h
@@ -125,6 +125,14 @@ } template <typename T> +std::enable_if_t<is_smart_pointer_v<T>, std::string> GetSmartPtrMaker() { + absl::string_view maker = is_unique_ptr_v<T> ? "std::make_unique" + : is_shared_ptr_v<T> ? "std::make_shared" + : "<MAKE_SMART_POINTER>"; + return absl::StrCat(maker, "<", GetTypeName<typename T::element_type>(), ">"); +} + +template <typename T> inline constexpr bool has_absl_stringify_v = absl::HasAbslStringify<T>::value; struct IntegralPrinter { @@ -613,6 +621,25 @@ } }; +struct SmartPointerPrinter { + template <typename T> + void PrintUserValue(const T& v, domain_implementor::RawSink out, + domain_implementor::PrintMode mode) { + static_assert(is_smart_pointer_v<T>, "T must be a smart pointer type."); + if (v == nullptr) { + absl::Format(out, "nullptr"); + return; + } + if (mode == domain_implementor::PrintMode::kSourceCode) { + absl::Format(out, "%s", GetSmartPtrMaker<T>()); + } + absl::Format(out, "("); + AutodetectTypePrinter<typename T::element_type>().PrintUserValue(*v, out, + mode); + absl::Format(out, ")"); + } +}; + struct UnknownPrinter { template <typename T> void PrintUserValue(const T& v, domain_implementor::RawSink out, @@ -658,6 +685,8 @@ return CustomPrinter{}; } else if constexpr (is_bindable_aggregate_v<T>) { return AutodetectAggregatePrinter{}; + } else if constexpr (is_smart_pointer_v<T> && is_complete_type_v<T>) { + return SmartPointerPrinter{}; } else { return UnknownPrinter{}; }
diff --git a/fuzztest/internal/type_support_test.cc b/fuzztest/internal/type_support_test.cc index afaa80b..c8bb7de 100644 --- a/fuzztest/internal/type_support_test.cc +++ b/fuzztest/internal/type_support_test.cc
@@ -17,6 +17,7 @@ #include <array> #include <cmath> #include <complex> +#include <cstddef> #include <cstdint> #include <limits> #include <list> @@ -606,5 +607,41 @@ Each("<unprintable value>")); } +TEST(AutodetectTypePrinterTest, DetectsSmartPointers) { + class CustomIntPtr { + public: + using element_type = int; + CustomIntPtr() : n_(std::monostate{}) {} + explicit CustomIntPtr(int n) : n_(n) {} + const int* get() const { return std::get_if<int>(&n_); } + int* get() { return &std::get<int>(n_); } + const int& operator*() const { return *get(); } + int& operator*() { return *get(); } + bool operator!() const { return n_.index() == 0; } + bool operator==(std::nullptr_t) const { return !*this; } + + private: + std::variant<std::monostate, int> n_; + }; + + static_assert(is_smart_pointer_v<std::unique_ptr<int>>); + static_assert(is_smart_pointer_v<const std::unique_ptr<int>>); + static_assert(is_smart_pointer_v<std::shared_ptr<int>>); + static_assert(is_smart_pointer_v<const std::shared_ptr<int>>); + static_assert(is_smart_pointer_v<CustomIntPtr>); + static_assert(is_smart_pointer_v<const CustomIntPtr>); + + EXPECT_THAT(TestPrintValue(std::unique_ptr<int>(nullptr)), Each("nullptr")); + EXPECT_THAT(TestPrintValue(std::shared_ptr<int>(nullptr)), Each("nullptr")); + EXPECT_THAT(TestPrintValue(CustomIntPtr()), Each("nullptr")); + + EXPECT_THAT(TestPrintValue(std::make_unique<int>(7)), + ElementsAre("(7)", "std::make_unique<int>(7)")); + EXPECT_THAT(TestPrintValue(std::make_shared<int>(7)), + ElementsAre("(7)", "std::make_shared<int>(7)")); + EXPECT_THAT(TestPrintValue(CustomIntPtr(7)), + ElementsAre("(7)", "<MAKE_SMART_POINTER><int>(7)")); +} + } // namespace } // namespace fuzztest::internal