// blob: 6368c994a5e804396956c934ee5c981c5cede29e [file] [log] [blame]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLAT_MAP_IMPL_H_
#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLAT_MAP_IMPL_H_
#include <cstddef>
#include <optional>
#include <tuple>
#include <type_traits>
#include <utility>
#include "absl/random/bit_gen_ref.h"
#include "absl/random/distributions.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "absl/types/span.h"
#include "./fuzztest/internal/domains/domain_base.h"
#include "./fuzztest/internal/domains/serialization_helpers.h"
#include "./fuzztest/internal/logging.h"
#include "./fuzztest/internal/meta.h"
#include "./fuzztest/internal/serialization.h"
#include "./fuzztest/internal/status.h"
#include "./fuzztest/internal/type_support.h"
namespace fuzztest::internal {
// FlatMap takes a domain factory function (flat mapper) and an input domain
// for each parameter of the factory function. The output domain is what the
// flat mapper returns and the domain that FlatMap represents. I.e., the "output
// domain" is re-created dynamically, as it depends on values created by the
// input domains.
template <typename FlatMapper, typename... InputDomain>
using FlatMapOutputDomain = std::decay_t<
std::invoke_result_t<FlatMapper, value_type_t<InputDomain>...>>;
// Domain implementation that combines a flat mapper with one input domain per
// flat-mapper parameter. The flat mapper is invoked on user values of the
// input domains and returns the "output domain"; values of this domain are the
// values of that output domain.
template <typename FlatMapper, typename... InputDomain>
class FlatMapImpl
    : public domain_implementor::DomainBase<
          FlatMapImpl<FlatMapper, InputDomain...>,
          // The user value is the user value of the output domain.
          value_type_t<FlatMapOutputDomain<FlatMapper, InputDomain...>>,
          // The corpus value is a tuple where the first element is the corpus
          // value of the output domain, and the rest is the corpus value of
          // the input domains.
          std::tuple<
              corpus_type_t<FlatMapOutputDomain<FlatMapper, InputDomain...>>,
              corpus_type_t<InputDomain>...>> {
 public:
  using typename FlatMapImpl::DomainBase::corpus_type;
  using typename FlatMapImpl::DomainBase::value_type;

  FlatMapImpl() = default;

  // Stores the flat mapper and the input domains (one per flat-mapper
  // parameter) by value.
  explicit FlatMapImpl(FlatMapper flat_mapper, InputDomain... input_domains)
      : flat_mapper_(std::move(flat_mapper)),
        input_domains_(std::move(input_domains)...) {}

  // Produces an initial corpus value. If `MaybeGetRandomSeed` yields a seed,
  // that seed is returned as is. Otherwise every input domain is initialized,
  // the output domain is built from those input values, and the output corpus
  // value is initialized from that output domain.
  corpus_type Init(absl::BitGenRef prng) {
    if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
    auto input_corpus = std::apply(
        [&](auto&... input_domains) {
          return std::make_tuple(input_domains.Init(prng)...);
        },
        input_domains_);
    auto output_domain = GetOutputDomain(input_corpus);
    // `input_corpus` is not used after this point, so move it into the result
    // instead of copying every input corpus value.
    return std::tuple_cat(std::make_tuple(output_domain.Init(prng)),
                          std::move(input_corpus));
  }

  // Mutates `val` in place. Either mutates the input values and
  // re-initializes the output value, or mutates only the output value within
  // the output domain derived from the current input values.
  void Mutate(corpus_type& val, absl::BitGenRef prng,
              const domain_implementor::MutationMetadata& metadata,
              bool only_shrink) {
    // There is no way to tell whether the current output corpus value is
    // consistent with a new output domain generated by mutated inputs, so
    // mutating the inputs forces re-initialization of the output domain. This
    // means that, when shrinking, we cannot mutate the inputs, as
    // re-initializing would lose the "still crashing" output value.
    bool mutate_inputs = !only_shrink && absl::Bernoulli(prng, 0.1);
    if (mutate_inputs) {
      ApplyIndex<kNumInputValues>([&](auto... I) {
        // The first field of `val` is the output corpus value, so skip it.
        (std::get<I>(input_domains_)
             .Mutate(std::get<I + 1>(val), prng, metadata, only_shrink),
         ...);
      });
      std::get<0>(val) = GetOutputDomain(val).Init(prng);
      return;
    }
    // For simplicity, we create a new output domain each call to `Mutate`. This
    // means that stateful domains don't work, but this is currently a matter of
    // convenience, not correctness. For example, `Filter` won't automatically
    // find when something is too restrictive.
    // TODO(b/246423623): Support stateful domains.
    GetOutputDomain(val).Mutate(std::get<0>(val), prng, metadata, only_shrink);
  }

  // Returns the user value: the output corpus value (first tuple element)
  // interpreted by the output domain rebuilt from the input values in `v`.
  value_type GetValue(const corpus_type& v) const {
    return GetOutputDomain(v).GetValue(std::get<0>(v));
  }

  // Always returns `std::nullopt`.
  std::optional<corpus_type> FromValue(const value_type&) const {
    // We cannot infer the input corpus from the output value, or even determine
    // from which output domain the output value came.
    return std::nullopt;
  }

  // Returns a printer constructed from the flat mapper and the input domains.
  auto GetPrinter() const {
    return FlatMappedPrinter<FlatMapper, InputDomain...>{flat_mapper_,
                                                         input_domains_};
  }

  // Parses a corpus value from `obj`. The input corpus values are parsed and
  // validated first, because the output domain needed to parse the leading
  // element can only be built from valid input values. Returns `std::nullopt`
  // if any parsing step fails or the input values are invalid; in the latter
  // case the validation message is also written to stderr.
  std::optional<corpus_type> ParseCorpus(const IRObject& obj) const {
    auto input_corpus = ParseWithDomainTuple(input_domains_, obj, /*skip=*/1);
    if (!input_corpus.has_value()) {
      return std::nullopt;
    }
    absl::Status input_values_validity = ValidateInputValues(*input_corpus);
    if (!input_values_validity.ok()) {
      // Terminate the diagnostic with a newline so that subsequent stderr
      // output does not end up on the same line.
      absl::FPrintF(GetStderr(), "[!] %s\n", input_values_validity.message());
      return std::nullopt;
    }
    auto output_domain = GetOutputDomain(*input_corpus);
    // We know obj.Subs()[0] exists because ParseWithDomainTuple succeeded.
    auto output_corpus = output_domain.ParseCorpus((*obj.Subs())[0]);
    if (!output_corpus.has_value()) {
      return std::nullopt;
    }
    // Both optionals are local and not used afterwards, so move their contents
    // into the result instead of copying them.
    return std::tuple_cat(std::make_tuple(std::move(*output_corpus)),
                          std::move(*input_corpus));
  }

  // Serializes `v` using a domain tuple laid out like the corpus tuple: the
  // output domain (rebuilt from `v`) followed by the input domains.
  IRObject SerializeCorpus(const corpus_type& v) const {
    auto domain =
        std::tuple_cat(std::make_tuple(GetOutputDomain(v)), input_domains_);
    return SerializeWithDomainTuple(domain, v);
  }

  // Validates `corpus_value`: the input values are checked against the input
  // domains first, and only if they are valid is the output value checked
  // against the output domain built from them.
  absl::Status ValidateCorpusValue(const corpus_type& corpus_value) const {
    // Check input values first.
    absl::Status input_values_validity = ValidateInputValues(corpus_value);
    if (!input_values_validity.ok()) return input_values_validity;
    // Check the output value.
    return GetOutputDomain(corpus_value)
        .ValidateCorpusValue(std::get<0>(corpus_value));
  }

 private:
  // Returns the output domain for a `tuple` with or without the output value
  // as the leading element, and with the input values as the last
  // `kNumInputValues` elements.
  template <typename Tuple>
  FlatMapOutputDomain<FlatMapper, InputDomain...> GetOutputDomain(
      const Tuple& tuple) const {
    static_assert(is_tuple_v<Tuple> &&
                  std::tuple_size_v<Tuple> >= kNumInputValues);
    // Number of leading elements that precede the input values.
    static constexpr size_t kOffset =
        std::tuple_size_v<Tuple> - kNumInputValues;
    return ApplyIndex<kNumInputValues>([&](auto... I) {
      // The first field of `tuple` may be the output corpus value, so skip it.
      return flat_mapper_(std::get<I>(input_domains_)
                              .GetValue(std::get<kOffset + I>(tuple))...);
    });
  }

  // Validates the input values for a `tuple` with or without the output value
  // as the leading element, and with the input values as the last
  // `kNumInputValues` elements. Input values are checked in order; checking
  // stops at the first failure, whose status is returned after being passed
  // through `Prefix`.
  template <typename Tuple>
  absl::Status ValidateInputValues(const Tuple& tuple) const {
    static_assert(is_tuple_v<Tuple> &&
                  std::tuple_size_v<Tuple> >= kNumInputValues);
    // Number of leading elements that precede the input values.
    static constexpr size_t kOffset =
        std::tuple_size_v<Tuple> - kNumInputValues;
    return ApplyIndex<kNumInputValues>([&](auto... I) {
      absl::Status input_values_validity = absl::OkStatus();
      (
          [&] {
            // A previous input value already failed; skip the remaining ones.
            if (!input_values_validity.ok()) return;
            const absl::Status s =
                std::get<I>(input_domains_)
                    .ValidateCorpusValue(std::get<kOffset + I>(tuple));
            input_values_validity =
                Prefix(s, "Invalid value for FlatMap()-ed domain");
          }(),
          ...);
      return input_values_validity;
    });
  }

  // Number of input domains, i.e., of trailing input values in a corpus tuple.
  static constexpr size_t kNumInputValues = sizeof...(InputDomain);

  FlatMapper flat_mapper_;
  std::tuple<InputDomain...> input_domains_;
};
} // namespace fuzztest::internal
#endif // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_FLAT_MAP_IMPL_H_