blob: 1ce34c7db172be60ff8972f4609e33bdd49fd936 [file] [log] [blame]
/*
* Copyright 2021 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
#define DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_
#include <glog/logging.h>
#include <openssl/cipher.h>
#include <memory>
#include <type_traits>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_map.h"
#include "absl/meta/type_traits.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "dpf/aes_128_fixed_key_hash.h"
#include "dpf/distributed_point_function.pb.h"
#include "dpf/internal/proto_validator.h"
#include "dpf/internal/value_type_helpers.h"
namespace distributed_point_functions {
// Type trait for all supported types. Used to provide meaningful error messages
// in std::enable_if template guards.
template <typename T>
using is_supported_type = dpf_internal::is_supported_type<T>;
template <typename T>
constexpr bool is_supported_type_v = is_supported_type<T>::value;
// Interprets `value` as an element of the supported type T. The actual
// conversion is delegated to `dpf_internal::ValueTypeHelper<T>::FromValue`.
//
// Returns INVALID_ARGUMENT if the conversion fails.
template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
absl::StatusOr<T> FromValue(const Value& value) {
  using Helper = dpf_internal::ValueTypeHelper<T>;
  return Helper::FromValue(value);
}
// Packs `input`, an element of a supported type T, into a `Value` proto. The
// conversion is delegated to `dpf_internal::ValueTypeHelper<T>::ToValue`.
template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
Value ToValue(const T& input) {
  using Helper = dpf_internal::ValueTypeHelper<T>;
  Value packed = Helper::ToValue(input);
  return packed;
}
// Builds the `ValueType` message that describes the supported type T, as
// produced by `dpf_internal::ValueTypeHelper<T>::ToValueType`.
template <typename T, typename = absl::enable_if_t<is_supported_type_v<T>>>
ValueType ToValueType() {
  using Helper = dpf_internal::ValueTypeHelper<T>;
  ValueType description = Helper::ToValueType();
  return description;
}
// Implements key generation and evaluation of distributed point functions.
// A distributed point function (DPF) is parameterized by an index `alpha` and a
// value `beta`. The key generation procedure produces two keys `k_a`, `k_b`.
// Evaluating each key on any point `x` in the DPF domain results in an additive
// secret share of `beta`, if `x == alpha`, and a share of 0 otherwise. This
// class also supports *incremental* DPFs that can additionally be evaluated on
// prefixes of points, resulting in different values `beta_i` for each prefix of
// `alpha`.
class DistributedPointFunction {
public:
// Creates a new instance of a distributed point function that can be
// evaluated only at the output layer.
//
// Returns INVALID_ARGUMENT if the parameters are invalid.
static absl::StatusOr<std::unique_ptr<DistributedPointFunction>> Create(
const DpfParameters& parameters);
// Creates a new instance of an *incremental* DPF that can be evaluated at
// multiple layers. Each parameter set in `parameters` should specify the
// domain size and element size at one of the layers to be evaluated, in
// increasing domain size order. Element sizes must be non-decreasing.
//
// Returns INVALID_ARGUMENT if the parameters are invalid.
static absl::StatusOr<std::unique_ptr<DistributedPointFunction>>
CreateIncremental(absl::Span<const DpfParameters> parameters);
// DistributedPointFunction is neither copyable nor movable. (Deleting the
// copy operations also suppresses the implicit move operations.)
DistributedPointFunction(const DistributedPointFunction&) = delete;
DistributedPointFunction& operator=(const DistributedPointFunction&) = delete;
// Converts the argument to a `Value` proto. Also registers the corresponding
// value type with the DPF by calling `RegisterValueType<T>()`.
//
// Returns the registration error if `RegisterValueType<T>()` fails.
template <typename T>
absl::StatusOr<Value> ToValue(const T& in) {
absl::Status status = RegisterValueType<T>();
if (!status.ok()) {
return status;
}
return distributed_point_functions::ToValue(in);
}
// Registers the template parameter type with this DPF. Note that it is rarely
// necessary to call this function by hand: It is called by `Create` and
// `CreateIncremental` for all unsigned integer types, including
// absl::uint128, and on every call to ToValue<T>. Only call this function
// when passing `Value`s created by other means than ToValue<T>.
//
// Returns OK on success and otherwise an INTERNAL status describing the
// failure.
template <typename T>
absl::Status RegisterValueType() {
return RegisterValueTypeImpl<T>(value_correction_functions_);
}
// Generates a pair of keys for a DPF that evaluates to `beta` when evaluated
// at `alpha`. The type of `beta` must match the ValueType passed in
// `parameters` at construction.
//
// This function provides three overloads: One with `absl::uint128` for
// `beta`, which implies the output type is a simple integer; One with a
// `Value` proto for `beta`, which can be used for all supported value types;
// And a templated version that computes the Value by calling ToValue<T> on
// the argument.
//
// Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
//
// // Simple integer:
// dpf->GenerateKeys(23, 42);
//
// // Explicit `Value` proto:
// Value value;
// value.mutable_tuple()->add_elements()
// ->mutable_integer()->set_value_uint64(12);
// value.mutable_tuple()->add_elements()
// ->mutable_integer()->set_value_uint64(34);
// // Must be called once before calling GenerateKeys for any type that is
// // not a simple integer. The type should match the one in the
// // DpfParameters passed at construction.
// dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
// dpf->GenerateKeys(23, value);
//
// // Templated version (no call to RegisterValueType needed):
// dpf->GenerateKeys(23, Tuple<uint32_t, uint64_t>{12, 34});
//
// Returns INVALID_ARGUMENT if used on an incremental DPF with more
// than one set of parameters, if `alpha` is outside of the domain specified
// at construction, or if `beta` does not match the value type passed at
// construction.
// Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
// for the type in the `DpfParameters` passed at construction.
// Overload for simple integers.
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
absl::uint128 beta) {
return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
}
// Overload for explicit Value proto.
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
Value beta) {
return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&beta, 1));
}
// Template for automatic conversion to Value proto. Disabled if the argument
// is convertible to `absl::uint128` or `Value` to make overloading
// unambiguous.
template <typename T, typename = absl::enable_if_t<
!std::is_convertible<T, absl::uint128>::value &&
!std::is_convertible<T, Value>::value &&
is_supported_type_v<T>>>
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeys(absl::uint128 alpha,
const T& beta) {
absl::StatusOr<Value> value = ToValue<T>(beta);
if (!value.ok()) {
return value.status();
}
return GenerateKeysIncremental(alpha, absl::MakeConstSpan(&(*value), 1));
}
// Generates a pair of keys for an incremental DPF. For each parameter i
// passed at construction, the DPF evaluates to `beta[i]` at the lowest
// `parameters_[i].log_domain_size()` bits of `alpha`.
//
// Similar to `GenerateKeys`, supports three overloads: One for simple
// integers, passed as an `absl::Span<const absl::uint128>`; One for a span of
// `Value` protos; And a variadic function template that automatically
// converts the passed arguments to a vector of `Value`s.
//
// Example Usages (assuming a std::unique_ptr<DistributedPointFunction> dpf):
//
// // Simple integers:
// std::vector<absl::uint128> beta{123, 456};
// dpf->GenerateKeysIncremental(23, beta);
//
// // Explicit Value protos:
// std::vector<Value> beta(2);
// beta[0].mutable_integer()->set_value_uint128(42);
// beta[1].mutable_tuple()->add_elements()
// ->mutable_integer()->set_value_uint64(12);
// beta[1].mutable_tuple()->add_elements()
// ->mutable_integer()->set_value_uint64(34);
// // Must be called once before calling GenerateKeys for any type that is
// // not a simple integer. The type should match the one in the
// // DpfParameters passed at construction.
// dpf->RegisterValueType<Tuple<uint32_t, uint64_t>>();
// dpf->GenerateKeysIncremental(23, beta);
//
// // Templated version (equivalent to the one above):
// dpf->GenerateKeysIncremental(23, 42, Tuple<uint32_t, uint64_t>{12, 34});
//
// Returns INVALID_ARGUMENT if `beta.size() != parameters_.size()`, if `alpha`
// is outside of the domain specified at construction, or if `beta` does not
// match the element type passed at construction.
// Returns FAILED_PRECONDITION if `RegisterValueType<T>` has not been called
// for all types in the `DpfParameters` passed at construction.
// Overload for simple integers. Each element is converted with the member
// `ToValue`, which also registers `absl::uint128` as a value type.
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
absl::uint128 alpha, absl::Span<const absl::uint128> beta) {
std::vector<Value> values(beta.size());
for (int i = 0; i < static_cast<int>(beta.size()); ++i) {
absl::StatusOr<Value> value = ToValue(beta[i]);
if (!value.ok()) {
return value.status();
}
values[i] = std::move(*value);
}
return GenerateKeysIncremental(alpha, values);
}
// Overload for Value protos.
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
absl::uint128 alpha, absl::Span<const Value> beta);
// Variadic template version. Disabled if the first argument is convertible to
// a span of `absl::uint128`s or `Value`s to make overloading unambiguous.
template <
typename T0, typename... Tn,
typename = absl::enable_if_t<
!std::is_convertible<T0, absl::Span<const Value>>::value &&
!std::is_convertible<T0, absl::Span<const absl::uint128>>::value &&
absl::conjunction<is_supported_type<T0>,
is_supported_type<Tn>...>::value>>
absl::StatusOr<std::pair<DpfKey, DpfKey>> GenerateKeysIncremental(
absl::uint128 alpha, T0&& beta_0, Tn&&... beta_n);
// Returns an `EvaluationContext` for incrementally evaluating the given
// DpfKey.
//
// Returns INVALID_ARGUMENT if `key` doesn't match the parameters given at
// construction.
absl::StatusOr<EvaluationContext> CreateEvaluationContext(DpfKey key) const;
// Evaluates the given `hierarchy_level` of the DPF under all `prefixes`
// passed to this function. If `prefixes` is empty, evaluation starts from the
// seed of `ctx.key`. Otherwise, each element of `prefixes` must fit in the
// domain size of `ctx.previous_hierarchy_level`. Further, `prefixes` may only
// contain extensions of the prefixes passed in the previous call. For
// example, in the following sequence of calls, for each element p2 of
// `prefixes2`, there must be an element p1 of `prefixes1` such that p1 is a
// prefix of p2:
//
// DPF_ASSIGN_OR_RETURN(EvaluationContext ctx,
// dpf->CreateEvaluationContext(key));
// using T0 = ...;
// DPF_ASSIGN_OR_RETURN(std::vector<T0> evaluations0,
// dpf->EvaluateUntil<T0>(0, {}, ctx));
//
// std::vector<absl::uint128> prefixes1 = ...;
// using T1 = ...;
// DPF_ASSIGN_OR_RETURN(std::vector<T1> evaluations1,
// dpf->EvaluateUntil<T1>(1, prefixes1, ctx));
// ...
// std::vector<absl::uint128> prefixes2 = ...;
// using T2 = ...;
// DPF_ASSIGN_OR_RETURN(std::vector<T2> evaluations2,
// dpf->EvaluateUntil<T2>(3, prefixes2, ctx));
//
// The prefixes are read from the lowest-order bits of the corresponding
// absl::uint128. The number of bits used for each prefix depends on the
// output domain size of the previously evaluated hierarchy level. For
// example, if `ctx` was last evaluated on a hierarchy level with output
// domain size 2**20, then the 20 lowest-order bits of each element in
// `prefixes` are used.
//
// The result holds, for each element of `prefixes` in order, all outputs of
// `hierarchy_level` below that prefix. If `prefixes` is empty, the result is
// the full expansion of `hierarchy_level`.
//
// Returns the validator's error if `ctx` is malformed.
// Returns `INVALID_ARGUMENT` if
// - `hierarchy_level` is negative, not less than `parameters().size()`, or
// not greater than `ctx.previous_hierarchy_level()`,
// - the value type of T doesn't match the value type configured for
// `hierarchy_level`,
// - `prefixes` is non-empty on the first call with `ctx`, or empty on a
// later call,
// - any element of `prefixes` doesn't fit in the domain of
// `ctx.previous_hierarchy_level()`,
// - the output for a single prefix would be larger than 2**62 elements, or
// - `prefixes` contains elements that are not extensions of previous
// prefixes.
template <typename T>
absl::StatusOr<std::vector<T>> EvaluateUntil(
int hierarchy_level, absl::Span<const absl::uint128> prefixes,
EvaluationContext& ctx) const;
// Convenience wrapper around `EvaluateUntil<T>` that evaluates the next
// hierarchy level of `ctx`: level 0 if `prefixes` is empty (i.e., on the
// first evaluation of `ctx`), and `ctx.previous_hierarchy_level() + 1`
// otherwise. Returns the same errors as `EvaluateUntil`.
template <typename T>
absl::StatusOr<std::vector<T>> EvaluateNext(
absl::Span<const absl::uint128> prefixes, EvaluationContext& ctx) const {
if (prefixes.empty()) {
return EvaluateUntil<T>(0, prefixes, ctx);
} else {
return EvaluateUntil<T>(ctx.previous_hierarchy_level() + 1, prefixes,
ctx);
}
}
// Evaluates a single key at one or multiple points, up to the given
// hierarchy_level. Each element of `evaluation_points` must be within the
// domain of this DPF at `hierarchy_level`.
//
// Example:
//
// DpfKey key = ...;
// std::vector<absl::uint128> evaluation_points = {1, 23, 42};
// // Evaluate `key` on {1, 23, 42}.
// DPF_ASSIGN_OR_RETURN(std::vector<T> result,
// dpf->EvaluateAt<T>(key, 0, evaluation_points));
//
// Returns INVALID_ARGUMENT if `key` is malformed, or if `hierarchy_level` or
// any element of `evaluation_points` is out of range.
template <typename T>
absl::StatusOr<std::vector<T>> EvaluateAt(
const DpfKey& key, int hierarchy_level,
absl::Span<const absl::uint128> evaluation_points) const;
// Returns the DpfParameters of this DPF.
inline absl::Span<const DpfParameters> parameters() const {
return parameters_;
}
private:
// BitVector is a vector of bools. Allows for faster access times than
// std::vector<bool>, as well as inlining if the size is small.
using BitVector =
absl::InlinedVector<bool,
std::max<size_t>(1, sizeof(bool*) / sizeof(bool))>;
// Seeds and control bits resulting from a DPF expansion. This type is
// returned by `ExpandSeeds` and `ExpandAndUpdateContext`.
struct DpfExpansion {
std::vector<absl::uint128> seeds;
BitVector control_bits;
};
// A function for computing value corrections. Used as return type in
// `GetValueCorrectionFunction`.
using ValueCorrectionFunction = absl::StatusOr<std::vector<Value>> (*)(
absl::string_view, absl::string_view, int block_index, const Value&,
bool);
// Private constructor, called by `CreateIncremental`.
DistributedPointFunction(
std::unique_ptr<dpf_internal::ProtoValidator> proto_validator,
std::vector<int> blocks_needed, Aes128FixedKeyHash prg_left,
Aes128FixedKeyHash prg_right, Aes128FixedKeyHash prg_value,
absl::flat_hash_map<std::string, ValueCorrectionFunction>
value_correction_functions);
// Computes the value correction for the given `hierarchy_level`, `seeds`,
// index `alpha` and value `beta`. If `invert` is true, the individual values
// in the returned block are multiplied element-wise by -1. Expands `seeds`
// using `prg_value_`, then calls the function returned by
// `GetValueCorrectionFunction(parameters_[hierarchy_level])` to obtain the
// value correction words.
//
// Returns multiple values in the case of packing, and a single Value
// otherwise.
//
// Returns INTERNAL in case the PRG expansion fails, and UNIMPLEMENTED if
// `element_bitsize` is not supported.
absl::StatusOr<std::vector<Value>> ComputeValueCorrection(
int hierarchy_level, absl::Span<const absl::uint128> seeds,
absl::uint128 alpha, const Value& beta, bool invert) const;
// Expands the PRG seeds at the next `tree_level` for an incremental DPF with
// index `alpha` and values `beta`, updates `seeds` and `control_bits`, and
// writes the next correction word to `keys`. Called from
// `GenerateKeysIncremental`.
absl::Status GenerateNext(int tree_level, absl::uint128 alpha,
absl::Span<const Value> beta,
absl::Span<absl::uint128> seeds,
absl::Span<bool> control_bits,
absl::Span<DpfKey> keys) const;
// Computes the tree index (representing a path in the FSS tree) from the
// given `domain_index` and `hierarchy_level`. Does NOT check whether the
// given domain index fits in the domain at `hierarchy_level`.
absl::uint128 DomainToTreeIndex(absl::uint128 domain_index,
int hierarchy_level) const;
// Computes the block index (pointing to an element in a batched 128-bit
// block) from the given `domain_index` and `hierarchy_level`. Does NOT check
// whether the given domain index fits in the domain at `hierarchy_level`.
int DomainToBlockIndex(absl::uint128 domain_index, int hierarchy_level) const;
// Performs DPF evaluation of the given `partial_evaluations` using
// `prg_left_` or `prg_right_`, and the given `correction_words`. At each
// level `l < correction_words.size()`, the evaluation for the i-th seed in
// `partial_evaluations` continues along the left or right path depending on
// the l-th most significant bit among the lowest `correction_words.size()`
// bits of `paths[i]`.
//
// Returns INVALID_ARGUMENT if the input sizes don't match.
// Returns INTERNAL in case of OpenSSL errors.
absl::StatusOr<DpfExpansion> EvaluateSeeds(
DpfExpansion partial_evaluations, absl::Span<const absl::uint128> paths,
absl::Span<const CorrectionWord* const> correction_words) const;
// Performs DPF expansion of the given `partial_evaluations` using
// `prg_left_` and `prg_right_`, and the given `correction_words`. In more
// detail, each of the partial evaluations is subjected to a full subtree
// expansion of `correction_words.size()` levels, and the concatenated result
// is provided in the response. The result contains
// `partial_evaluations.size() * 2^correction_words.size()` evaluations in
// a single `DpfExpansion`.
//
// Returns INTERNAL in case of OpenSSL errors.
absl::StatusOr<DpfExpansion> ExpandSeeds(
const DpfExpansion& partial_evaluations,
absl::Span<const CorrectionWord* const> correction_words) const;
// Computes partial evaluations of the paths to `prefixes` to be used as the
// starting point of the expansion of `ctx`. If `update_ctx == true`, saves
// the partial evaluations of `ctx.previous_hierarchy_level` to `ctx` and sets
// `ctx.partial_evaluations_level` to `ctx.previous_hierarchy_level`.
// Called by `ExpandAndUpdateContext`.
//
// Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
// `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
// duplicate seeds.
absl::StatusOr<DpfExpansion> ComputePartialEvaluations(
absl::Span<const absl::uint128> prefixes, bool update_ctx,
EvaluationContext& ctx) const;
// Extracts the seeds for the given `prefixes` from `ctx` and expands them as
// far as needed for the next hierarchy level. Returns the result as a
// `DpfExpansion`. Called by `EvaluateUntil`, where the expanded seeds are
// corrected to obtain output values.
// After expansion, `ctx.hierarchy_level()` is increased. If this isn't the
// last expansion, the expanded seeds are also saved in `ctx` for the next
// expansion.
//
// Returns INVALID_ARGUMENT if any element of `prefixes` is not found in
// `ctx.partial_evaluations()`, or `ctx.partial_evaluations()` contains
// duplicate seeds. Returns INTERNAL in case of OpenSSL errors.
absl::StatusOr<DpfExpansion> ExpandAndUpdateContext(
int hierarchy_level, absl::Span<const absl::uint128> prefixes,
EvaluationContext& ctx) const;
// Computes the output PRG value of the expanded seeds using `prg_value_`.
// Returns blocks_needed_[hierarchy_level] * expansion.size() blocks, where
// each consecutive run of blocks_needed_[hierarchy_level] blocks is the hash
// of one input seed.
//
// Returns INTERNAL in case of OpenSSL errors.
absl::StatusOr<std::vector<absl::uint128>> HashExpandedSeeds(
int hierarchy_level, absl::Span<const absl::uint128> expansion) const;
// Deterministically serializes the given value_type.
//
// Returns OK on success and INTERNAL in case serialization fails.
static absl::StatusOr<std::string> SerializeValueTypeDeterministically(
const ValueType& value_type);
// Returns the value correction function for the given parameters.
// For all value types except unsigned integers, these functions have to be
// first registered using RegisterValueType<T>.
//
// Returns UNIMPLEMENTED if no matching function was registered.
absl::StatusOr<ValueCorrectionFunction> GetValueCorrectionFunction(
const DpfParameters& parameters) const;
// Static implementation of RegisterValueType<T>, so we can call it from
// `Create`. Inserts (or overwrites) the entry for T in
// `value_correction_functions`.
template <typename T>
static absl::Status RegisterValueTypeImpl(
absl::flat_hash_map<std::string, ValueCorrectionFunction>&
value_correction_functions);
// Used to validate DpfParameters, DpfKey and EvaluationContext protos.
const std::unique_ptr<dpf_internal::ProtoValidator> proto_validator_;
// DPF parameters passed to the factory function. Contains the domain size and
// element size for each hierarchy level of the incremental DPF. Owned by
// proto_validator_.
const absl::Span<const DpfParameters> parameters_;
// Number of levels in the evaluation tree. This is always less than or equal
// to the largest log_domain_size in parameters_.
const int tree_levels_needed_;
// Maps levels of the FSS evaluation tree to hierarchy levels (i.e., elements
// of parameters_). Stored by reference, so the referenced map must outlive
// this object (NOTE(review): presumably owned by proto_validator_ — confirm).
const absl::flat_hash_map<int, int>& tree_to_hierarchy_;
// The inverse of tree_to_hierarchy_. Also stored by reference (same lifetime
// caveat as tree_to_hierarchy_).
const std::vector<int>& hierarchy_to_tree_;
// Cached numbers of AES blocks needed for value correction at each hierarchy
// level.
const std::vector<int> blocks_needed_;
// Pseudorandom generator used for seed expansion (left and right), and value
// correction. The PRG G(x) for hierarchy level i is defined as the
// concatenation of
//
// H_left(x), H_right(x), H_value(x + 0), ..., H_value(x + k-1)
//
// where k is equal to blocks_needed_[i], and H_*(x) is the evaluation of
// prg_*_ on input x.
const Aes128FixedKeyHash prg_left_;
const Aes128FixedKeyHash prg_right_;
const Aes128FixedKeyHash prg_value_;
// Maps serialized `ValueType` messages to the correct value correction
// functions. Map values are instantiations of
// `dpf_internal::ComputeValueCorrectionFor`. Relies on protobuf's
// deterministic serialization feature. This has the caveat that messages with
// unknown fields are not supported. However, as long as `ValueType` consists
// of a single `oneof` field, this is fine, since we either know the value
// type and have deterministic serialization because the `ValueType` can only
// contain one field, or we don't know the type and wouldn't be able to
// correct values for it anyway.
absl::flat_hash_map<std::string, ValueCorrectionFunction>
value_correction_functions_;
};
//========================//
// Implementation Details //
//========================//
// Stores `dpf_internal::ComputeValueCorrectionFor<T>` in
// `value_correction_functions`, keyed by the deterministic serialization of
// T's `ValueType`. An existing entry for the same key is overwritten.
//
// Returns the serialization error if the `ValueType` cannot be serialized.
template <typename T>
absl::Status DistributedPointFunction::RegisterValueTypeImpl(
    absl::flat_hash_map<std::string, ValueCorrectionFunction>&
        value_correction_functions) {
  absl::StatusOr<std::string> key =
      SerializeValueTypeDeterministically(ToValueType<T>());
  if (!key.ok()) return key.status();

  value_correction_functions[*key] = dpf_internal::ComputeValueCorrectionFor<T>;
  return absl::OkStatus();
}
// Variadic overload of `GenerateKeysIncremental`. Converts `beta_0` and every
// element of `beta_n` to a `Value` with the member `ToValue` (which also
// registers each value type), stopping at the first failed conversion. The
// resulting values are forwarded to the `absl::Span<const Value>` overload.
//
// Returns the first conversion error, or whatever the span overload returns.
template <typename T0, typename... Tn, typename /*= absl::enable_if_t<...>*/>
absl::StatusOr<std::pair<DpfKey, DpfKey>>
DistributedPointFunction::GenerateKeysIncremental(absl::uint128 alpha,
                                                  T0&& beta_0, Tn&&... beta_n) {
  // Convert the first element of beta. We need to treat it separately to be
  // able to check its type in the enable_if of the declaration.
  absl::StatusOr<Value> first_value = ToValue(beta_0);
  if (!first_value.ok()) {
    return first_value.status();
  }
  std::vector<Value> values;
  values.reserve(1 + sizeof...(beta_n));
  // Not `values = {std::move(*first_value)}`: the elements of an
  // std::initializer_list are const, so that form copies the proto.
  values.push_back(std::move(*first_value));

  // Convert all values in the parameter pack, stopping at the first error.
  absl::Status status = absl::OkStatus();
  auto convert = [this, &status, &values](const auto& beta_i) {
    if (!status.ok()) {
      return;
    }
    absl::StatusOr<Value> value = this->ToValue(beta_i);
    if (value.ok()) {
      values.push_back(std::move(*value));
    } else {
      status = value.status();
    }
  };
  // The initializers of a braced-init-list are evaluated left to right, so
  // expanding the pack here converts beta_n in order. The leading 0 keeps the
  // array non-empty when `beta_n` is empty. Unlike a throwaway
  // std::tuple<Tn...>, this does not require any Tn to be
  // default-constructible. In C++17, a fold expression could replace this.
  const int in_order[] = {0, (convert(beta_n), 0)...};
  static_cast<void>(in_order);

  // Return if there was an error during conversion, otherwise generate keys.
  if (!status.ok()) {
    return status;
  }
  return GenerateKeysIncremental(alpha, values);
}
// Evaluates `hierarchy_level` of the key in `ctx` under all `prefixes` and
// advances `ctx` (via `ExpandAndUpdateContext`) so that later calls can
// continue from `hierarchy_level`. See the declaration for the full contract.
//
// Result layout: for each element of `prefixes`, in order, the
// 2**(log_domain_size(hierarchy_level) - log_domain_size(previous level))
// outputs below that prefix. If `prefixes` is empty, the full expansion of
// `hierarchy_level` is returned.
//
// All output sizes and offsets are computed in 64 bits, because a single call
// may produce up to 2**62 outputs.
template <typename T>
absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateUntil(
    int hierarchy_level, absl::Span<const absl::uint128> prefixes,
    EvaluationContext& ctx) const {
  absl::Status status = proto_validator_->ValidateEvaluationContext(ctx);
  if (!status.ok()) {
    return status;
  }
  if (hierarchy_level < 0 ||
      hierarchy_level >= static_cast<int>(parameters_.size())) {
    return absl::InvalidArgumentError(
        "`hierarchy_level` must be non-negative and less than "
        "parameters_.size()");
  }
  absl::StatusOr<bool> types_are_equal = dpf_internal::ValueTypesAreEqual(
      ToValueType<T>(), parameters_[hierarchy_level].value_type());
  if (!types_are_equal.ok()) {
    return types_are_equal.status();
  } else if (!*types_are_equal) {
    return absl::InvalidArgumentError(
        "Value type T doesn't match parameters at `hierarchy_level`");
  }
  if (hierarchy_level <= ctx.previous_hierarchy_level()) {
    return absl::InvalidArgumentError(
        "`hierarchy_level` must be greater than "
        "`ctx.previous_hierarchy_level`");
  }
  if ((ctx.previous_hierarchy_level() < 0) != (prefixes.empty())) {
    return absl::InvalidArgumentError(
        "`prefixes` must be empty if and only if this is the first call with "
        "`ctx`.");
  }
  int previous_log_domain_size = 0;
  int previous_hierarchy_level = ctx.previous_hierarchy_level();
  if (!prefixes.empty()) {
    DCHECK(previous_hierarchy_level >= 0);
    previous_log_domain_size =
        parameters_[previous_hierarchy_level].log_domain_size();
    for (absl::uint128 prefix : prefixes) {
      // A shift by 128 would be undefined; every uint128 fits a domain of size
      // 2**128 anyway.
      if (previous_log_domain_size < 128 &&
          prefix >= (absl::uint128{1} << previous_log_domain_size)) {
        return absl::InvalidArgumentError(
            absl::StrFormat("Index %d out of range for hierarchy level %d",
                            prefix, previous_hierarchy_level));
      }
    }
  }
  int64_t prefixes_size = static_cast<int64_t>(prefixes.size());
  int log_domain_size = parameters_[hierarchy_level].log_domain_size();
  if (log_domain_size - previous_log_domain_size > 62) {
    return absl::InvalidArgumentError(
        "Output size would be larger than 2**62. Please evaluate fewer "
        "hierarchy levels at once.");
  }
  // The `prefixes` passed in by the caller refer to the domain of the previous
  // hierarchy level. However, because we batch multiple elements of type T in a
  // single uint128 block, multiple prefixes can actually refer to the same
  // block in the FSS evaluation tree. On a high level, our approach is as
  // follows:
  //
  // 1. Split up each element of `prefixes` into a tree index, pointing to a
  //    block in the FSS tree, and a block index, pointing to an element of type
  //    T in that block.
  //
  // 2. Compute a list of unique `tree_indices`, and for each original prefix,
  //    remember the position of the corresponding tree index in
  //    `tree_indices`.
  //
  // 3. After expanding the unique `tree_indices`, use the positions saved in
  //    Step (2) together with the corresponding block index to retrieve the
  //    expanded values for each prefix, and return them in the same order as
  //    `prefixes`.
  //
  // `tree_indices` holds the unique tree indices from `prefixes`, to be passed
  // to `ExpandAndUpdateContext`.
  std::vector<absl::uint128> tree_indices;
  tree_indices.reserve(prefixes_size);
  // `tree_indices_inverse` is the inverse of `tree_indices`, used for
  // deduplicating and constructing `prefix_map`. Use a btree_map because we
  // expect `prefixes` (and thus `tree_indices`) to be sorted.
  absl::btree_map<absl::uint128, int64_t> tree_indices_inverse;
  // `prefix_map` maps each i < prefixes.size() to an element of `tree_indices`
  // and a block index. Used to select which elements to return after the
  // expansion, to ensure the result is ordered the same way as `prefixes`.
  std::vector<std::pair<int64_t, int>> prefix_map;
  prefix_map.reserve(prefixes_size);
  for (int64_t i = 0; i < prefixes_size; ++i) {
    absl::uint128 tree_index =
        DomainToTreeIndex(prefixes[i], previous_hierarchy_level);
    int block_index = DomainToBlockIndex(prefixes[i], previous_hierarchy_level);
    // Check if `tree_index` already exists in `tree_indices`. The end() hint
    // makes insertion cheap for sorted input.
    auto previous_size = tree_indices_inverse.size();
    auto it = tree_indices_inverse.try_emplace(tree_indices_inverse.end(),
                                               tree_index, tree_indices.size());
    if (tree_indices_inverse.size() > previous_size) {
      tree_indices.push_back(tree_index);
    }
    prefix_map.push_back(std::make_pair(it->second, block_index));
  }
  // Perform expansion of unique `tree_indices`.
  absl::StatusOr<DpfExpansion> expansion =
      ExpandAndUpdateContext(hierarchy_level, tree_indices, ctx);
  if (!expansion.ok()) {
    return expansion.status();
  }
  // Hash the expanded seeds.
  absl::StatusOr<std::vector<absl::uint128>> hashed_expansion =
      HashExpandedSeeds(hierarchy_level, expansion->seeds);
  if (!hashed_expansion.ok()) {
    return hashed_expansion.status();
  }
  // Get output correction word from `ctx`.
  constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
  const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
  if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
    value_correction =
        &(ctx.key()
              .correction_words(hierarchy_to_tree_[hierarchy_level])
              .value_correction());
  } else {
    // Last level value correction is stored in an extra proto field, since we
    // have one less correction word than tree levels.
    value_correction = &(ctx.key().last_level_value_correction());
  }
  // Split output correction into elements of type T.
  absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
      dpf_internal::ValuesToArray<T>(*value_correction);
  if (!correction_ints.ok()) {
    return correction_ints.status();
  }
  // Compute value corrections for each block in `expanded_seeds`. We have to
  // account for the fact that blocks might not be full (i.e., have less than
  // elements_per_block elements).
  const int corrected_elements_per_block =
      1 << (parameters_[hierarchy_level].log_domain_size() -
            hierarchy_to_tree_[hierarchy_level]);
  const auto expansion_size = static_cast<int64_t>(expansion->seeds.size());
  const int blocks_needed = blocks_needed_[hierarchy_level];
  DCHECK(corrected_elements_per_block <= elements_per_block);
  std::vector<T> corrected_expansion(expansion_size *
                                     corrected_elements_per_block);
  for (int64_t i = 0; i < expansion_size; ++i) {
    std::array<T, elements_per_block> current_elements =
        dpf_internal::ConvertBytesToArrayOf<T>(
            absl::string_view(reinterpret_cast<const char*>(
                                  &(*hashed_expansion)[i * blocks_needed]),
                              blocks_needed * sizeof(absl::uint128)));
    for (int j = 0; j < corrected_elements_per_block; ++j) {
      if (expansion->control_bits[i]) {
        current_elements[j] += (*correction_ints)[j];
      }
      if (ctx.key().party() == 1) {
        current_elements[j] = -current_elements[j];
      }
      corrected_expansion[i * corrected_elements_per_block + j] =
          current_elements[j];
    }
  }
  // Compute the number of outputs we will have. For each prefix, we will have a
  // full expansion from the previous hierarchy level to the current hierarchy
  // level.
  DCHECK(log_domain_size - previous_log_domain_size < 63);
  int64_t outputs_per_prefix = int64_t{1}
                               << (log_domain_size - previous_log_domain_size);
  if (prefixes.empty()) {
    // If prefixes is empty (i.e., this is the first evaluation of `ctx`), just
    // return the expansion. Compare in 64 bits: the expansion may hold more
    // than 2**31 elements, and truncating its size to int would make this
    // check fail spuriously.
    DCHECK(static_cast<int64_t>(corrected_expansion.size()) ==
           outputs_per_prefix);
    return corrected_expansion;
  } else {
    // Otherwise, only return elements under `prefixes`. This ratio can exceed
    // the range of int for large expansions, so keep it 64-bit.
    const int64_t blocks_per_tree_prefix =
        expansion_size / static_cast<int64_t>(tree_indices.size());
    std::vector<T> result(prefixes_size * outputs_per_prefix);
    for (int64_t i = 0; i < prefixes_size; ++i) {
      int64_t prefix_expansion_start =
          prefix_map[i].first * blocks_per_tree_prefix *
              corrected_elements_per_block +
          prefix_map[i].second * outputs_per_prefix;
      std::copy_n(&corrected_expansion[prefix_expansion_start],
                  outputs_per_prefix, &result[i * outputs_per_prefix]);
    }
    return result;
  }
}
// Evaluates `key` at each of `evaluation_points` on hierarchy level
// `hierarchy_level`. Returns one element of type T per evaluation point, in
// the same order as `evaluation_points`.
//
// Returns INVALID_ARGUMENT if `hierarchy_level` is negative or not less than
// the number of parameters given at construction, if `key` fails validation,
// or if any evaluation point is not smaller than 2^log_domain_size of
// `hierarchy_level`. Errors from converting the key's value correction to T,
// from seed evaluation, and from seed hashing are returned unchanged.
template <typename T>
absl::StatusOr<std::vector<T>> DistributedPointFunction::EvaluateAt(
    const DpfKey& key, int hierarchy_level,
    absl::Span<const absl::uint128> evaluation_points) const {
  const auto num_evaluation_points =
      static_cast<int64_t>(evaluation_points.size());
  if (hierarchy_level < 0) {
    return absl::InvalidArgumentError("`hierarchy_level` must be non-negative");
  }
  if (hierarchy_level >= static_cast<int>(parameters_.size())) {
    return absl::InvalidArgumentError(
        "`hierarchy_level` must be less than the number of parameters passed "
        "at construction");
  }
  absl::Status status = proto_validator_->ValidateDpfKey(key);
  if (!status.ok()) {
    return status;
  }

  // Reject points outside the domain of this hierarchy level instead of
  // relying on how DomainToTreeIndex / EvaluateSeeds treat the excess high
  // bits (not visible here; presumably they are dropped, which would silently
  // return the output for a different point). A 2^128 domain admits every
  // uint128, and a shift by 128 would be undefined, hence the guard.
  const int log_domain_size = parameters_[hierarchy_level].log_domain_size();
  if (log_domain_size < 128) {
    const absl::uint128 domain_size = absl::uint128{1} << log_domain_size;
    for (int64_t i = 0; i < num_evaluation_points; ++i) {
      if (evaluation_points[i] >= domain_size) {
        return absl::InvalidArgumentError(absl::StrFormat(
            "`evaluation_points[%d]` larger than the domain size", i));
      }
    }
  }

  // Select the value correction for this level. Every level except the last
  // reads it from the correction word of the tree level this hierarchy level
  // maps to. The last level's correction is stored in a separate proto field,
  // because there is one correction word fewer than tree levels.
  constexpr int elements_per_block = dpf_internal::ElementsPerBlock<T>();
  const ::google::protobuf::RepeatedPtrField<Value>* value_correction = nullptr;
  if (hierarchy_level < static_cast<int>(parameters_.size()) - 1) {
    value_correction =
        &(key.correction_words(hierarchy_to_tree_[hierarchy_level])
              .value_correction());
  } else {
    value_correction = &(key.last_level_value_correction());
  }
  // Split the correction into `elements_per_block` elements of type T.
  absl::StatusOr<std::array<T, elements_per_block>> correction_ints =
      dpf_internal::ValuesToArray<T>(*value_correction);
  if (!correction_ints.ok()) {
    return correction_ints.status();
  }

  // For packed types (more than one element of T per block), map every
  // evaluation point to its tree index. Otherwise the evaluation points are
  // used directly as tree indices, which avoids a copy.
  std::vector<absl::uint128> maybe_recomputed_tree_indices;
  absl::Span<const absl::uint128> tree_indices;
  if (elements_per_block > 1) {
    maybe_recomputed_tree_indices.reserve(num_evaluation_points);
    for (int64_t i = 0; i < num_evaluation_points; ++i) {
      maybe_recomputed_tree_indices.push_back(
          DomainToTreeIndex(evaluation_points[i], hierarchy_level));
    }
    tree_indices = absl::MakeConstSpan(maybe_recomputed_tree_indices);
  } else {
    tree_indices = evaluation_points;
  }

  // Every point starts from the key's seed, with the party bit as its initial
  // control bit.
  const absl::uint128 seed =
      absl::MakeUint128(key.seed().high(), key.seed().low());
  const bool party = key.party();
  DpfExpansion inputs;
  inputs.seeds.resize(num_evaluation_points, seed);
  inputs.control_bits.resize(num_evaluation_points, party);

  // Evaluate the seeds using only the first `stop_level` correction words,
  // i.e. down to the tree level that `hierarchy_level` maps to.
  const int stop_level = hierarchy_to_tree_[hierarchy_level];
  auto correction_words =
      absl::MakeConstSpan(key.correction_words()).subspan(0, stop_level);
  absl::StatusOr<DpfExpansion> evaluated_inputs =
      EvaluateSeeds(std::move(inputs), tree_indices, correction_words);
  if (!evaluated_inputs.ok()) {
    return evaluated_inputs.status();
  }
  DCHECK(static_cast<int64_t>(evaluated_inputs->seeds.size()) ==
         num_evaluation_points);

  // Hash the evaluated seeds. The loop below reads `blocks_needed`
  // consecutive hashed blocks per evaluation point.
  absl::StatusOr<std::vector<absl::uint128>> hashed_expansion =
      HashExpandedSeeds(hierarchy_level, evaluated_inputs->seeds);
  if (!hashed_expansion.ok()) {
    return hashed_expansion.status();
  }

  // Value correction, for each point:
  //   1. Reinterpret its hashed blocks as elements of T.
  //   2. Pick the element it addresses.
  //   3. Add the correction if its control bit is set.
  //   4. Negate the result for party 1.
  std::vector<T> result;
  result.reserve(num_evaluation_points);
  const int blocks_needed = blocks_needed_[hierarchy_level];
  for (int64_t i = 0; i < num_evaluation_points; ++i) {
    std::array<T, elements_per_block> current_elements =
        dpf_internal::ConvertBytesToArrayOf<T>(
            absl::string_view(reinterpret_cast<const char*>(
                                  &(*hashed_expansion)[i * blocks_needed]),
                              blocks_needed * sizeof(absl::uint128)));
    int block_index = 0;
    if (elements_per_block > 1) {
      block_index = DomainToBlockIndex(evaluation_points[i], hierarchy_level);
    }
    T value = current_elements[block_index];
    if (evaluated_inputs->control_bits[i]) {
      value += (*correction_ints)[block_index];
    }
    if (party) {
      value = -value;
    }
    result.push_back(std::move(value));
  }
  return result;
}
} // namespace distributed_point_functions
#endif // DISTRIBUTED_POINT_FUNCTIONS_DPF_DISTRIBUTED_POINT_FUNCTION_H_