blob: 20db2bb2daeaa02568ef4082c63e594ca5286507 [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dcf/distributed_comparison_function.h"
#include "dpf/status_macros.h"
namespace distributed_point_functions {
namespace {
// Overwrites `value` in place with a zero of the same shape.
//
// Integer and IntModN leaves are set to a uint64 value of 0. Every element of
// a Tuple is zeroed recursively. Values whose case is anything else (including
// an unset value) are left unchanged.
void SetToZero(Value& value) {
  switch (value.value_case()) {
    case Value::kInteger:
      value.mutable_integer()->set_value_uint64(0);
      break;
    case Value::kIntModN:
      value.mutable_int_mod_n()->set_value_uint64(0);
      break;
    case Value::kTuple:
      // Recurse into each tuple component so nested tuples are zeroed too.
      for (Value& element : *value.mutable_tuple()->mutable_elements()) {
        SetToZero(element);
      }
      break;
    default:
      // Nothing to zero for other/unset cases.
      break;
  }
}
} // namespace
DistributedComparisonFunction::DistributedComparisonFunction(
DcfParameters parameters, std::unique_ptr<DistributedPointFunction> dpf)
: parameters_(std::move(parameters)), dpf_(std::move(dpf)) {}
// Validates `parameters` and builds a DCF backed by an incremental DPF.
//
// The incremental DPF gets one hierarchy level per bit of the DCF domain:
// level i (for i = 0, ..., log_domain_size - 1) has log_domain_size i. Every
// level uses the DCF's value type.
//
// Returns InvalidArgumentError if
//   - parameters.parameters().log_domain_size() is outside [1, 128], or
//   - parameters.parameters().value_type is not set.
// Any error from the DPF parameter validator or from
// DistributedPointFunction::CreateIncremental is returned unchanged.
absl::StatusOr<std::unique_ptr<DistributedComparisonFunction>>
DistributedComparisonFunction::Create(const DcfParameters& parameters) {
  const int log_domain_size = parameters.parameters().log_domain_size();
  // A DCF with a single-element domain doesn't make sense.
  if (log_domain_size < 1) {
    return absl::InvalidArgumentError("A DCF must have log_domain_size >= 1");
  }
  // Domain elements (and `alpha` in GenerateKeys) are absl::uint128, so the
  // domain can have at most 2^128 elements. GenerateKeys also computes
  // `absl::uint128{1} << (log_domain_size - 1)`, which is undefined for shift
  // amounts >= 128. The DPF levels built below only go up to
  // log_domain_size - 1, so a per-level size limit in the DPF validator would
  // not catch log_domain_size == 129. Reject it explicitly here.
  if (log_domain_size > 128) {
    return absl::InvalidArgumentError("A DCF must have log_domain_size <= 128");
  }
  // We don't support the legacy element_bitsize field in DCFs.
  if (!parameters.parameters().has_value_type()) {
    return absl::InvalidArgumentError(
        "parameters.value_type must be set for "
        "DistributedComparisonFunction::Create");
  }
  // Create parameter vector for the incremental DPF: level i has a domain of
  // 2^i elements, and all levels share the DCF's value type.
  std::vector<DpfParameters> dpf_parameters(log_domain_size);
  for (int i = 0; i < log_domain_size; ++i) {
    dpf_parameters[i].set_log_domain_size(i);
    *(dpf_parameters[i].mutable_value_type()) =
        parameters.parameters().value_type();
  }
  // Check that parameters are valid. We can use the DPF proto validator
  // directly.
  DPF_RETURN_IF_ERROR(
      dpf_internal::ProtoValidator::ValidateParameters(dpf_parameters));
  // Create incremental DPF.
  DPF_ASSIGN_OR_RETURN(
      std::unique_ptr<DistributedPointFunction> dpf,
      DistributedPointFunction::CreateIncremental(dpf_parameters));
  // WrapUnique + new rather than make_unique: the constructor is invoked from
  // this static member, which may be the only place with access to it.
  return absl::WrapUnique(
      new DistributedComparisonFunction(parameters, std::move(dpf)));
}
// Generates the two DCF keys for threshold `alpha` and payload `beta`.
//
// One value is built per incremental-DPF level. Level i is paired with bit
// (log_domain_size - i - 1) of `alpha`, so bits are consumed from most to
// least significant. A level keeps `beta` when that bit is 1 and gets a zeroed
// copy of `beta` when the bit is 0.
//
// The incremental DPF is keyed with `alpha >> 1`. Its deepest level has
// log_domain_size - 1 bits, and the lowest bit of `alpha` is already encoded
// in the last level's value.
//
// Errors from DistributedPointFunction::GenerateKeysIncremental (for example
// an out-of-range alpha or a mismatched beta) are returned unchanged.
absl::StatusOr<std::pair<DcfKey, DcfKey>>
DistributedComparisonFunction::GenerateKeys(absl::uint128 alpha,
                                            const Value& beta) {
  const int num_levels = parameters_.parameters().log_domain_size();
  std::vector<Value> level_values(num_levels, beta);

  // Walk alpha's bits from the top (position num_levels - 1) down to 0,
  // in step with the levels.
  int bit_position = num_levels - 1;
  for (Value& level_value : level_values) {
    const absl::uint128 alpha_bit =
        (alpha >> bit_position) & absl::uint128{1};
    if (alpha_bit == absl::uint128{0}) {
      SetToZero(level_value);
    }
    --bit_position;
  }

  // The last bit of alpha is captured by level_values.back(), so the DPF
  // only needs the remaining high bits.
  const absl::uint128 dpf_alpha = alpha >> 1;
  std::pair<DcfKey, DcfKey> keys;
  DPF_ASSIGN_OR_RETURN(
      std::tie(*keys.first.mutable_key(), *keys.second.mutable_key()),
      dpf_->GenerateKeysIncremental(dpf_alpha, level_values));
  return keys;
}
} // namespace distributed_point_functions