blob: 0f3ef7c5dccf7726c0f83a928875667ad4163f47 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/aggregation_service/aggregation_service_test_utils.h"
#include <stddef.h>
#include <stdint.h>
#include <optional>
#include <ostream>
#include <string>
#include <tuple>
#include <vector>
#include "base/base64.h"
#include "base/check.h"
#include "base/containers/contains.h"
#include "base/containers/span.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/json/json_reader.h"
#include "base/not_fatal_until.h"
#include "base/strings/strcat.h"
#include "base/task/thread_pool.h"
#include "base/threading/sequence_bound.h"
#include "base/time/clock.h"
#include "base/time/time.h"
#include "base/types/expected_macros.h"
#include "base/uuid.h"
#include "base/values.h"
#include "content/browser/aggregation_service/aggregatable_report.h"
#include "content/browser/aggregation_service/aggregation_service_storage.h"
#include "content/browser/aggregation_service/aggregation_service_storage_sql.h"
#include "content/browser/aggregation_service/public_key.h"
#include "content/browser/aggregation_service/public_key_parsing_utils.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/mojom/aggregation_service/aggregatable_report.mojom.h"
#include "third_party/boringssl/src/include/openssl/hpke.h"
#include "url/gurl.h"
#include "url/origin.h"
namespace content {
namespace aggregation_service {
testing::AssertionResult PublicKeysEqual(const std::vector<PublicKey>& expected,
const std::vector<PublicKey>& actual) {
const auto tie = [](const PublicKey& key) {
return std::make_tuple(key.id, key.key);
};
if (expected.size() != actual.size()) {
return testing::AssertionFailure() << "Expected length " << expected.size()
<< ", actual: " << actual.size();
}
for (size_t i = 0; i < expected.size(); i++) {
if (tie(expected[i]) != tie(actual[i])) {
return testing::AssertionFailure()
<< "Expected " << expected[i] << " at index " << i
<< ", actual: " << actual[i];
}
}
return testing::AssertionSuccess();
}
using AggregationServicePayload = AggregatableReport::AggregationServicePayload;
testing::AssertionResult AggregatableReportsEqual(
const AggregatableReport& expected,
const AggregatableReport& actual) {
if (expected.payloads().size() != actual.payloads().size()) {
return testing::AssertionFailure()
<< "Expected payloads size " << expected.payloads().size()
<< ", actual: " << actual.payloads().size();
}
for (size_t i = 0; i < expected.payloads().size(); ++i) {
const AggregationServicePayload& expected_payload = expected.payloads()[i];
const AggregationServicePayload& actual_payload = actual.payloads()[i];
if (expected_payload.payload != actual_payload.payload) {
return testing::AssertionFailure()
<< "Expected payloads at payload index " << i << " to match";
}
if (expected_payload.key_id != actual_payload.key_id) {
return testing::AssertionFailure()
<< "Expected key_id " << expected_payload.key_id
<< " at payload index " << i
<< ", actual: " << actual_payload.key_id;
}
}
if (expected.shared_info() != actual.shared_info()) {
return testing::AssertionFailure()
<< "Expected shared info " << expected.shared_info()
<< ", actual: " << actual.shared_info();
}
if (expected.additional_fields() != actual.additional_fields()) {
return testing::AssertionFailure() << "Expected additional fields to match";
}
return testing::AssertionSuccess();
}
testing::AssertionResult ReportRequestsEqual(
const AggregatableReportRequest& expected,
const AggregatableReportRequest& actual) {
if (expected.processing_urls().size() != actual.processing_urls().size()) {
return testing::AssertionFailure()
<< "Expected processing_urls size "
<< expected.processing_urls().size()
<< ", actual: " << actual.processing_urls().size();
}
for (size_t i = 0; i < expected.processing_urls().size(); ++i) {
if (expected.processing_urls()[i] != actual.processing_urls()[i]) {
return testing::AssertionFailure()
<< "Expected processing_urls()[" << i << "] to be "
<< expected.processing_urls()[i]
<< ", actual: " << actual.processing_urls()[i];
}
}
testing::AssertionResult payload_contents_equal = PayloadContentsEqual(
expected.payload_contents(), actual.payload_contents());
if (!payload_contents_equal)
return payload_contents_equal;
if (expected.reporting_path() != actual.reporting_path()) {
return testing::AssertionFailure()
<< "Expected reporting_path " << expected.reporting_path()
<< ", actual: " << actual.reporting_path();
}
if (expected.additional_fields() != actual.additional_fields()) {
return testing::AssertionFailure() << "Expected additional fields to match";
}
return SharedInfoEqual(expected.shared_info(), actual.shared_info());
}
testing::AssertionResult PayloadContentsEqual(
const AggregationServicePayloadContents& expected,
const AggregationServicePayloadContents& actual) {
if (expected.operation != actual.operation) {
return testing::AssertionFailure()
<< "Expected operation " << expected.operation
<< ", actual: " << actual.operation;
}
if (expected.contributions.size() != actual.contributions.size()) {
return testing::AssertionFailure()
<< "Expected contributions.size() " << expected.contributions.size()
<< ", actual: " << actual.contributions.size();
}
for (size_t i = 0; i < expected.contributions.size(); ++i) {
if (expected.contributions[i].bucket != actual.contributions[i].bucket) {
return testing::AssertionFailure()
<< "Expected contribution " << i << " bucket "
<< expected.contributions[i].bucket
<< ", actual: " << actual.contributions[i].bucket;
}
if (expected.contributions[i].value != actual.contributions[i].value) {
return testing::AssertionFailure()
<< "Expected contribution " << i << " value "
<< expected.contributions[i].value
<< ", actual: " << actual.contributions[i].value;
}
}
if (expected.aggregation_mode != actual.aggregation_mode) {
return testing::AssertionFailure()
<< "Expected aggregation_mode " << expected.aggregation_mode
<< ", actual: " << actual.aggregation_mode;
}
if (expected.aggregation_coordinator_origin !=
actual.aggregation_coordinator_origin) {
return testing::AssertionFailure()
<< "Expected aggregation_coordinator_origin "
<< (expected.aggregation_coordinator_origin
? expected.aggregation_coordinator_origin->Serialize()
: "null")
<< ", actual: "
<< (actual.aggregation_coordinator_origin
? actual.aggregation_coordinator_origin->Serialize()
: "null");
}
return testing::AssertionSuccess();
}
testing::AssertionResult SharedInfoEqual(
const AggregatableReportSharedInfo& expected,
const AggregatableReportSharedInfo& actual) {
if (expected.scheduled_report_time != actual.scheduled_report_time) {
return testing::AssertionFailure()
<< "Expected scheduled_report_time "
<< expected.scheduled_report_time
<< ", actual: " << actual.scheduled_report_time;
}
if (expected.report_id != actual.report_id) {
return testing::AssertionFailure()
<< "Expected report_id " << expected.report_id
<< ", actual: " << actual.report_id;
}
if (expected.reporting_origin != actual.reporting_origin) {
return testing::AssertionFailure()
<< "Expected reporting_origin " << expected.reporting_origin
<< ", actual: " << actual.reporting_origin;
}
if (expected.debug_mode != actual.debug_mode) {
return testing::AssertionFailure()
<< "Expected debug_mode " << expected.debug_mode
<< ", actual: " << actual.debug_mode;
}
if (expected.additional_fields != actual.additional_fields) {
return testing::AssertionFailure()
<< "Expected additional_fields " << expected.additional_fields
<< ", actual: " << actual.additional_fields;
}
if (expected.api_version != actual.api_version) {
return testing::AssertionFailure()
<< "Expected api_version " << expected.api_version
<< ", actual: " << actual.api_version;
}
if (expected.api_identifier != actual.api_identifier) {
return testing::AssertionFailure()
<< "Expected api_identifier " << expected.api_identifier
<< ", actual: " << actual.api_identifier;
}
return testing::AssertionSuccess();
}
AggregatableReportRequest CreateExampleRequest(
blink::mojom::AggregationServiceMode aggregation_mode,
int failed_send_attempts,
std::optional<url::Origin> aggregation_coordinator_origin) {
return CreateExampleRequestWithReportTime(
/*report_time=*/base::Time::Now(), aggregation_mode, failed_send_attempts,
std::move(aggregation_coordinator_origin));
}
AggregatableReportRequest CreateExampleRequestWithReportTime(
base::Time report_time,
blink::mojom::AggregationServiceMode aggregation_mode,
int failed_send_attempts,
std::optional<url::Origin> aggregation_coordinator_origin) {
return AggregatableReportRequest::Create(
AggregationServicePayloadContents(
AggregationServicePayloadContents::Operation::kHistogram,
{blink::mojom::AggregatableReportHistogramContribution(
/*bucket=*/123, /*value=*/456,
/*filtering_id=*/std::nullopt)},
aggregation_mode, std::move(aggregation_coordinator_origin),
/*max_contributions_allowed=*/20,
/*filtering_id_max_bytes=*/std::nullopt),
AggregatableReportSharedInfo(
/*scheduled_report_time=*/report_time,
/*report_id=*/
base::Uuid::GenerateRandomV4(),
url::Origin::Create(GURL("https://reporting.example")),
AggregatableReportSharedInfo::DebugMode::kDisabled,
/*additional_fields=*/base::Value::Dict(),
/*api_version=*/"",
/*api_identifier=*/"example-api"),
/*reporting_path=*/"example-path",
/*debug_key=*/std::nullopt, /*additional_fields=*/{},
failed_send_attempts)
.value();
}
AggregatableReportRequest CloneReportRequest(
const AggregatableReportRequest& request) {
return AggregatableReportRequest::CreateForTesting(
request.processing_urls(), request.payload_contents(),
request.shared_info().Clone(), request.reporting_path(),
request.debug_key(), request.additional_fields(),
request.failed_send_attempts())
.value();
}
AggregatableReport CloneAggregatableReport(const AggregatableReport& report) {
std::vector<AggregationServicePayload> payloads;
for (const AggregationServicePayload& payload : report.payloads()) {
payloads.emplace_back(payload.payload, payload.key_id,
payload.debug_cleartext_payload);
}
return AggregatableReport(std::move(payloads), report.shared_info(),
report.debug_key(), report.additional_fields(),
report.aggregation_coordinator_origin());
}
TestHpkeKey::TestHpkeKey(std::string key_id) : key_id_(std::move(key_id)) {
EXPECT_TRUE(EVP_HPKE_KEY_generate(full_hpke_key_.get(),
EVP_hpke_x25519_hkdf_sha256()));
}
TestHpkeKey::TestHpkeKey(TestHpkeKey&&) = default;
TestHpkeKey& TestHpkeKey::operator=(TestHpkeKey&&) = default;
TestHpkeKey::~TestHpkeKey() = default;
PublicKey TestHpkeKey::GetPublicKey() const {
std::vector<uint8_t> public_key(X25519_PUBLIC_VALUE_LEN);
size_t public_key_len;
EXPECT_TRUE(EVP_HPKE_KEY_public_key(
/*key=*/full_hpke_key_.get(), /*out=*/public_key.data(),
/*out_len=*/&public_key_len, /*max_out=*/public_key.size()));
EXPECT_EQ(public_key.size(), public_key_len);
return PublicKey(key_id_, std::move(public_key));
}
std::string TestHpkeKey::GetPublicKeyBase64() const {
return base::Base64Encode(GetPublicKey().key);
}
base::expected<PublicKeyset, std::string> ReadAndParsePublicKeys(
const base::FilePath& file,
base::Time now) {
if (!base::PathExists(file)) {
return base::unexpected("Failed to open file: " + file.MaybeAsASCII());
}
std::string contents;
if (!base::ReadFileToString(file, &contents)) {
return base::unexpected("Failed to read file: " + file.MaybeAsASCII());
}
ASSIGN_OR_RETURN(
base::Value value,
base::JSONReader::ReadAndReturnValueWithError(contents),
[&](base::JSONReader::Error error) {
return base::StrCat({"Failed to parse \"", contents,
"\" as JSON: ", std::move(error).message});
});
std::vector<PublicKey> keys = GetPublicKeys(value);
if (keys.empty()) {
return base::unexpected(
base::StrCat({"Failed to parse public keys from \"", contents, "\""}));
}
return PublicKeyset(std::move(keys), /*fetch_time=*/now,
/*expiry_time=*/base::Time::Max());
}
std::vector<uint8_t> DecryptPayloadWithHpke(
base::span<const uint8_t> payload,
const EVP_HPKE_KEY& key,
const std::string& expected_serialized_shared_info) {
base::span<const uint8_t> enc = payload.first<X25519_PUBLIC_VALUE_LEN>();
std::string authenticated_info_str =
base::StrCat({AggregatableReport::kDomainSeparationPrefix,
expected_serialized_shared_info});
base::span<const uint8_t> authenticated_info =
base::as_bytes(base::make_span(authenticated_info_str));
// No null terminators should have been copied when concatenating the strings.
CHECK(!base::Contains(authenticated_info_str, '\0'));
bssl::ScopedEVP_HPKE_CTX recipient_context;
if (!EVP_HPKE_CTX_setup_recipient(
/*ctx=*/recipient_context.get(), /*key=*/&key,
/*kdf=*/EVP_hpke_hkdf_sha256(),
/*aead=*/EVP_hpke_chacha20_poly1305(),
/*enc=*/enc.data(), /*enc_len=*/enc.size(),
/*info=*/authenticated_info.data(),
/*info_len=*/authenticated_info.size())) {
return {};
}
base::span<const uint8_t> ciphertext =
payload.subspan(X25519_PUBLIC_VALUE_LEN);
std::vector<uint8_t> plaintext(ciphertext.size());
size_t plaintext_len;
if (!EVP_HPKE_CTX_open(
/*ctx=*/recipient_context.get(), /*out=*/plaintext.data(),
/*out_len*/ &plaintext_len, /*max_out_len=*/plaintext.size(),
/*in=*/ciphertext.data(), /*in_len=*/ciphertext.size(),
/*ad=*/nullptr,
/*ad_len=*/0)) {
return {};
}
plaintext.resize(plaintext_len);
return plaintext;
}
} // namespace aggregation_service
TestAggregationServiceStorageContext::TestAggregationServiceStorageContext(
const base::Clock* clock)
: storage_(base::SequenceBound<AggregationServiceStorageSql>(
base::ThreadPool::CreateSequencedTaskRunner({base::MayBlock()}),
/*run_in_memory=*/true,
/*path_to_database=*/base::FilePath(),
clock)) {}
TestAggregationServiceStorageContext::~TestAggregationServiceStorageContext() =
default;
const base::SequenceBound<content::AggregationServiceStorage>&
TestAggregationServiceStorageContext::GetStorage() {
return storage_;
}
MockAggregationService::MockAggregationService() = default;
MockAggregationService::~MockAggregationService() = default;
void MockAggregationService::AddObserver(AggregationServiceObserver* observer) {
observers_.AddObserver(observer);
}
void MockAggregationService::RemoveObserver(
AggregationServiceObserver* observer) {
observers_.RemoveObserver(observer);
}
void MockAggregationService::NotifyRequestStorageModified() {
for (auto& observer : observers_) {
observer.OnRequestStorageModified();
}
}
void MockAggregationService::NotifyReportHandled(
const AggregatableReportRequest& request,
std::optional<AggregationServiceStorage::RequestId> id,
std::optional<AggregatableReport> report,
base::Time report_handled_time,
AggregationServiceObserver::ReportStatus status) {
for (auto& observer : observers_)
observer.OnReportHandled(request, id, report, report_handled_time, status);
}
AggregatableReportRequestsAndIdsBuilder::
AggregatableReportRequestsAndIdsBuilder() = default;
AggregatableReportRequestsAndIdsBuilder::
~AggregatableReportRequestsAndIdsBuilder() = default;
AggregatableReportRequestsAndIdsBuilder&&
AggregatableReportRequestsAndIdsBuilder::AddRequestWithID(
AggregatableReportRequest request,
AggregationServiceStorage::RequestId id) && {
requests_.push_back(AggregationServiceStorage::RequestAndId({
.request = std::move(request),
.id = id,
}));
return std::move(*this);
}
std::vector<AggregationServiceStorage::RequestAndId>
AggregatableReportRequestsAndIdsBuilder::Build() && {
return std::move(requests_);
}
std::ostream& operator<<(
std::ostream& out,
AggregationServicePayloadContents::Operation operation) {
switch (operation) {
case AggregationServicePayloadContents::Operation::kHistogram:
return out << "kHistogram";
}
}
std::ostream& operator<<(
std::ostream& out,
blink::mojom::AggregationServiceMode aggregation_mode) {
switch (aggregation_mode) {
case blink::mojom::AggregationServiceMode::kTeeBased:
return out << "kTeeBased";
case blink::mojom::AggregationServiceMode::kExperimentalPoplar:
return out << "kExperimentalPoplar";
}
}
std::ostream& operator<<(std::ostream& out,
AggregatableReportSharedInfo::DebugMode debug_mode) {
switch (debug_mode) {
case AggregatableReportSharedInfo::DebugMode::kDisabled:
return out << "kDisabled";
case AggregatableReportSharedInfo::DebugMode::kEnabled:
return out << "kEnabled";
}
}
bool operator==(const PublicKey& a, const PublicKey& b) {
const auto tie = [](const PublicKey& public_key) {
return std::make_tuple(public_key.id, public_key.key);
};
return tie(a) == tie(b);
}
bool operator==(const AggregatableReport& a, const AggregatableReport& b) {
return aggregation_service::AggregatableReportsEqual(a, b);
}
} // namespace content