blob: 4a40fd11afff629d680ec3deca8d784ee133d531 [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/aggregation_service/aggregation_service_test_utils.h"
#include <stddef.h>
#include <stdint.h>
#include <ostream>
#include <string>
#include <tuple>
#include <vector>
#include "base/base64.h"
#include "base/containers/contains.h"
#include "base/task/thread_pool.h"
#include "base/threading/sequence_bound.h"
#include "base/time/time.h"
#include "content/browser/aggregation_service/aggregatable_report.h"
#include "content/browser/aggregation_service/aggregation_service_storage_sql.h"
#include "content/browser/aggregation_service/public_key.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/boringssl/src/include/openssl/hpke.h"
#include "url/gurl.h"
#include "url/origin.h"
namespace content {
namespace aggregation_service {
testing::AssertionResult PublicKeysEqual(const std::vector<PublicKey>& expected,
const std::vector<PublicKey>& actual) {
const auto tie = [](const PublicKey& key) {
return std::make_tuple(key.id, key.key);
};
if (expected.size() != actual.size()) {
return testing::AssertionFailure() << "Expected length " << expected.size()
<< ", actual: " << actual.size();
}
for (size_t i = 0; i < expected.size(); i++) {
if (tie(expected[i]) != tie(actual[i])) {
return testing::AssertionFailure()
<< "Expected " << expected[i] << " at index " << i
<< ", actual: " << actual[i];
}
}
return testing::AssertionSuccess();
}
using AggregationServicePayload = AggregatableReport::AggregationServicePayload;
testing::AssertionResult AggregatableReportsEqual(
const AggregatableReport& expected,
const AggregatableReport& actual) {
if (expected.payloads().size() != actual.payloads().size()) {
return testing::AssertionFailure()
<< "Expected payloads size " << expected.payloads().size()
<< ", actual: " << actual.payloads().size();
}
for (size_t i = 0; i < expected.payloads().size(); ++i) {
const AggregationServicePayload& expected_payload = expected.payloads()[i];
const AggregationServicePayload& actual_payload = actual.payloads()[i];
if (expected_payload.origin != actual_payload.origin) {
return testing::AssertionFailure()
<< "Expected origin " << expected_payload.origin
<< " at payload index " << i
<< ", actual: " << actual_payload.origin;
}
if (expected_payload.payload != actual_payload.payload) {
return testing::AssertionFailure()
<< "Expected payloads at payload index " << i << " to match";
}
if (expected_payload.key_id != actual_payload.key_id) {
return testing::AssertionFailure()
<< "Expected key_id " << expected_payload.key_id
<< " at payload index " << i
<< ", actual: " << actual_payload.key_id;
}
}
return SharedInfoEqual(expected.shared_info(), actual.shared_info());
}
testing::AssertionResult ReportRequestsEqual(
const AggregatableReportRequest& expected,
const AggregatableReportRequest& actual) {
if (expected.processing_origins().size() !=
actual.processing_origins().size()) {
return testing::AssertionFailure()
<< "Expected processing_origins size "
<< expected.processing_origins().size()
<< ", actual: " << actual.processing_origins().size();
}
for (size_t i = 0; i < expected.processing_origins().size(); ++i) {
if (expected.processing_origins()[i] != actual.processing_origins()[i]) {
return testing::AssertionFailure()
<< "Expected processing_origins()[" << i << "] to be "
<< expected.processing_origins()[i]
<< ", actual: " << actual.processing_origins()[i];
}
}
testing::AssertionResult payload_contents_equal = PayloadContentsEqual(
expected.payload_contents(), actual.payload_contents());
if (!payload_contents_equal)
return payload_contents_equal;
return SharedInfoEqual(expected.shared_info(), actual.shared_info());
}
testing::AssertionResult PayloadContentsEqual(
const AggregationServicePayloadContents& expected,
const AggregationServicePayloadContents& actual) {
if (expected.operation != actual.operation) {
return testing::AssertionFailure()
<< "Expected operation " << expected.operation
<< ", actual: " << actual.operation;
}
if (expected.bucket != actual.bucket) {
return testing::AssertionFailure() << "Expected bucket " << expected.bucket
<< ", actual: " << actual.bucket;
}
if (expected.value != actual.value) {
return testing::AssertionFailure() << "Expected value " << expected.value
<< ", actual: " << actual.value;
}
if (expected.processing_type != actual.processing_type) {
return testing::AssertionFailure()
<< "Expected processing_type " << expected.processing_type
<< ", actual: " << actual.processing_type;
}
return testing::AssertionSuccess();
}
testing::AssertionResult SharedInfoEqual(
const AggregatableReportSharedInfo& expected,
const AggregatableReportSharedInfo& actual) {
if (expected.scheduled_report_time != actual.scheduled_report_time) {
return testing::AssertionFailure()
<< "Expected scheduled_report_time "
<< expected.scheduled_report_time
<< ", actual: " << actual.scheduled_report_time;
}
if (expected.privacy_budget_key != actual.privacy_budget_key) {
return testing::AssertionFailure()
<< "Expected privacy_budget_key " << expected.privacy_budget_key
<< ", actual: " << actual.privacy_budget_key;
}
return testing::AssertionSuccess();
}
std::vector<url::Origin> GetExampleProcessingOrigins() {
return {url::Origin::Create(GURL("https://a.example")),
url::Origin::Create(GURL("https://b.example"))};
}
AggregatableReportRequest CreateExampleRequest(
AggregationServicePayloadContents::ProcessingType processing_type) {
return CreateExampleRequest(processing_type, GetExampleProcessingOrigins());
}
AggregatableReportRequest CreateExampleRequest(
AggregationServicePayloadContents::ProcessingType processing_type,
std::vector<url::Origin> processing_origins) {
return AggregatableReportRequest::Create(
std::move(processing_origins),
AggregationServicePayloadContents(
AggregationServicePayloadContents::Operation::
kHierarchicalHistogram,
/*bucket=*/123, /*value=*/456, processing_type,
url::Origin::Create(GURL("https://reporting.example"))),
AggregatableReportSharedInfo(
/*scheduled_report_time=*/base::Time::Now(),
/*privacy_budget_key=*/"example_budget_key"))
.value();
}
AggregatableReportRequest CloneReportRequest(
const AggregatableReportRequest& request) {
return AggregatableReportRequest::Create(request.processing_origins(),
request.payload_contents(),
request.shared_info())
.value();
}
AggregatableReport CloneAggregatableReport(const AggregatableReport& report) {
std::vector<AggregationServicePayload> payloads;
for (const AggregationServicePayload& payload : report.payloads()) {
payloads.emplace_back(payload.origin, payload.payload, payload.key_id);
}
return AggregatableReport(std::move(payloads), report.shared_info());
}
TestHpkeKey GenerateKey(std::string key_id) {
bssl::ScopedEVP_HPKE_KEY key;
EXPECT_TRUE(EVP_HPKE_KEY_generate(key.get(), EVP_hpke_x25519_hkdf_sha256()));
std::vector<uint8_t> public_key(X25519_PUBLIC_VALUE_LEN);
size_t public_key_len;
EXPECT_TRUE(EVP_HPKE_KEY_public_key(
/*key=*/key.get(), /*out=*/public_key.data(),
/*out_len=*/&public_key_len, /*max_out=*/public_key.size()));
EXPECT_EQ(public_key.size(), public_key_len);
TestHpkeKey hpke_key{
{}, PublicKey(key_id, public_key), base::Base64Encode(public_key)};
EVP_HPKE_KEY_copy(&hpke_key.full_hpke_key, key.get());
return hpke_key;
}
} // namespace aggregation_service
TestAggregationServiceStorageContext::TestAggregationServiceStorageContext(
const base::Clock* clock)
: storage_(base::SequenceBound<AggregationServiceStorageSql>(
base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}),
/*run_in_memory=*/true,
/*path_to_database=*/base::FilePath(),
clock)) {}
TestAggregationServiceStorageContext::~TestAggregationServiceStorageContext() =
default;
const base::SequenceBound<content::AggregationServiceKeyStorage>&
TestAggregationServiceStorageContext::GetKeyStorage() {
return storage_;
}
TestAggregationServiceKeyFetcher::TestAggregationServiceKeyFetcher()
: AggregationServiceKeyFetcher(/*storage_context=*/nullptr,
/*network_fetcher=*/nullptr) {}
TestAggregationServiceKeyFetcher::~TestAggregationServiceKeyFetcher() = default;
void TestAggregationServiceKeyFetcher::GetPublicKey(const url::Origin& origin,
FetchCallback callback) {
callbacks_[origin].push_back(std::move(callback));
}
void TestAggregationServiceKeyFetcher::TriggerPublicKeyResponse(
const url::Origin& origin,
absl::optional<PublicKey> key,
PublicKeyFetchStatus status) {
ASSERT_TRUE(base::Contains(callbacks_, origin))
<< "No corresponding GetPublicKeys call for origin " << origin;
ASSERT_EQ(key.has_value(), status == PublicKeyFetchStatus::kOk)
<< "Key must be returned if and only if status is kOk";
std::vector<FetchCallback> callbacks = std::move(callbacks_[origin]);
callbacks_.erase(origin);
for (FetchCallback& callback : callbacks) {
std::move(callback).Run(key, status);
}
}
void TestAggregationServiceKeyFetcher::TriggerPublicKeyResponseForAllOrigins(
absl::optional<PublicKey> key,
PublicKeyFetchStatus status) {
std::vector<url::Origin> all_origins_;
for (const auto& elem : callbacks_) {
all_origins_.push_back(elem.first);
}
for (auto& origin : all_origins_) {
TriggerPublicKeyResponse(std::move(origin), key, status);
}
}
bool TestAggregationServiceKeyFetcher::HasPendingCallbacks() {
return !callbacks_.empty();
}
TestAggregatableReportProvider::TestAggregatableReportProvider() = default;
TestAggregatableReportProvider::~TestAggregatableReportProvider() = default;
absl::optional<AggregatableReport>
TestAggregatableReportProvider::CreateFromRequestAndPublicKeys(
AggregatableReportRequest report_request,
std::vector<PublicKey> public_keys) const {
++num_calls_;
previous_request_ = aggregation_service::CloneReportRequest(report_request);
previous_public_keys_ = public_keys;
EXPECT_TRUE(report_to_return_.has_value());
return aggregation_service::CloneAggregatableReport(
report_to_return_.value());
}
std::ostream& operator<<(
std::ostream& out,
const AggregationServicePayloadContents::Operation& operation) {
switch (operation) {
case AggregationServicePayloadContents::Operation::kHierarchicalHistogram:
return out << "kHierarchicalHistogram";
}
}
std::ostream& operator<<(
std::ostream& out,
const AggregationServicePayloadContents::ProcessingType& processing_type) {
switch (processing_type) {
case AggregationServicePayloadContents::ProcessingType::kTwoParty:
return out << "kTwoParty";
case AggregationServicePayloadContents::ProcessingType::kSingleServer:
return out << "kSingleServer";
}
}
} // namespace content