blob: 4db281978c7816820bee1ce13824ef8f333cd1cc [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/aggregation_service/aggregation_service_impl.h"
#include <stdint.h>
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "base/callback.h"
#include "base/callback_helpers.h"
#include "base/containers/contains.h"
#include "base/containers/flat_map.h"
#include "base/files/scoped_temp_dir.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "content/browser/aggregation_service/aggregatable_report.h"
#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
#include "content/browser/aggregation_service/aggregatable_report_scheduler.h"
#include "content/browser/aggregation_service/aggregatable_report_sender.h"
#include "content/browser/aggregation_service/aggregation_service_storage.h"
#include "content/browser/aggregation_service/aggregation_service_test_utils.h"
#include "content/public/test/browser_task_environment.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"
namespace content {
// TODO(alexmt): Consider rewriting these tests using gmock.
class TestAggregatableReportAssembler : public AggregatableReportAssembler {
public:
explicit TestAggregatableReportAssembler(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
: AggregatableReportAssembler(
/*storage_context=*/nullptr,
std::move(url_loader_factory)) {}
~TestAggregatableReportAssembler() override = default;
void AssembleReport(AggregatableReportRequest request,
AssemblyCallback callback) override {
callbacks_.emplace(unique_id_counter_++, std::move(callback));
}
void TriggerResponse(int64_t report_id,
absl::optional<AggregatableReport> report,
AssemblyStatus status) {
ASSERT_TRUE(base::Contains(callbacks_, report_id));
ASSERT_EQ(report.has_value(), status == AssemblyStatus::kOk);
std::move(callbacks_[report_id]).Run(std::move(report), status);
callbacks_.erase(report_id);
}
private:
int64_t unique_id_counter_ = 0;
std::map<int64_t, AssemblyCallback> callbacks_;
};
class TestAggregatableReportSender : public AggregatableReportSender {
public:
explicit TestAggregatableReportSender(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
: AggregatableReportSender(std::move(url_loader_factory)) {}
~TestAggregatableReportSender() override = default;
void SendReport(const GURL& url,
const base::Value& contents,
ReportSentCallback callback) override {
callbacks_.emplace(unique_id_counter_++, std::move(callback));
}
void TriggerResponse(int64_t report_id, RequestStatus status) {
ASSERT_TRUE(base::Contains(callbacks_, report_id));
std::move(callbacks_[report_id]).Run(status);
callbacks_.erase(report_id);
}
private:
int64_t unique_id_counter_ = 0;
std::map<int64_t, ReportSentCallback> callbacks_;
};
// Test double for `AggregatableReportScheduler`. Scheduled requests are held
// in memory under sequential IDs (starting at 1) and are only released when
// the test calls `TriggerReportingTime()`. Success/failure notifications are
// recorded so tests can query them afterwards.
class TestAggregatableReportScheduler : public AggregatableReportScheduler {
 public:
  // The base class is handed a no-op callback; the supplied
  // `on_scheduled_report_time_reached` is kept here and is only run from
  // `TriggerReportingTime()`.
  TestAggregatableReportScheduler(
      AggregationServiceStorageContext* storage_context,
      base::RepeatingCallback<
          void(std::vector<AggregationServiceStorage::RequestAndId>)>
          on_scheduled_report_time_reached)
      : AggregatableReportScheduler(storage_context, base::DoNothing()),
        on_scheduled_report_time_reached_(
            std::move(on_scheduled_report_time_reached)) {}

  ~TestAggregatableReportScheduler() override = default;

  // Holds on to `request` under the next unused ID.
  void ScheduleRequest(AggregatableReportRequest request) override {
    const int64_t assigned_id = unique_id_counter_;
    ++unique_id_counter_;
    scheduled_reports_.emplace(assigned_id, std::move(request));
  }

  // Records `request_id` as having completed successfully.
  void NotifyInProgressRequestSucceeded(
      AggregationServiceStorage::RequestId request_id) override {
    completed_requests_status_[request_id] = true;
  }

  // Records `request_id` as having completed unsuccessfully.
  void NotifyInProgressRequestFailed(
      AggregationServiceStorage::RequestId request_id) override {
    completed_requests_status_[request_id] = false;
  }

  // Removes the requests named by `request_ids` from the scheduled set and
  // passes them, in the given order, to the callback supplied at construction.
  // Fails the current test (without running the callback) as soon as an ID
  // that is not currently scheduled is encountered.
  void TriggerReportingTime(
      std::vector<AggregationServiceStorage::RequestId> request_ids) {
    std::vector<AggregationServiceStorage::RequestAndId> due_requests;

    for (AggregationServiceStorage::RequestId request_id : request_ids) {
      ASSERT_TRUE(base::Contains(scheduled_reports_, request_id));

      auto scheduled_it = scheduled_reports_.find(request_id);
      due_requests.push_back(AggregationServiceStorage::RequestAndId{
          .request = std::move(scheduled_it->second), .id = request_id});
      scheduled_reports_.erase(scheduled_it);
    }

    on_scheduled_report_time_reached_.Run(std::move(due_requests));
  }

  // Returns a boolean representing whether the request was successfully
  // completed. Returns absl::nullopt if the request has not yet completed.
  absl::optional<bool> WasRequestSuccessful(
      AggregationServiceStorage::RequestId request_id) {
    const auto status_it = completed_requests_status_.find(request_id);
    if (status_it == completed_requests_status_.end()) {
      return absl::nullopt;
    }
    return status_it->second;
  }

 private:
  // Run by `TriggerReportingTime()` with the requests being released.
  base::RepeatingCallback<void(
      std::vector<AggregationServiceStorage::RequestAndId>)>
      on_scheduled_report_time_reached_;

  // ID handed to the next `ScheduleRequest()` call.
  int64_t unique_id_counter_ = 1;

  // Requests that have been scheduled but not yet released.
  base::flat_map<AggregationServiceStorage::RequestId,
                 AggregatableReportRequest>
      scheduled_reports_;

  // Each completed request's ID is the key, with value whether it was completed
  // successfully.
  base::flat_map<AggregationServiceStorage::RequestId, bool>
      completed_requests_status_;
};
class AggregationServiceImplTest : public testing::Test {
public:
AggregationServiceImplTest()
: task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME),
storage_context_(task_environment_.GetMockClock()) {
EXPECT_TRUE(dir_.CreateUniqueTempDir());
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory =
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
&test_url_loader_factory_);
auto assembler =
std::make_unique<TestAggregatableReportAssembler>(url_loader_factory);
test_assembler_ = assembler.get();
auto sender =
std::make_unique<TestAggregatableReportSender>(url_loader_factory);
test_sender_ = sender.get();
auto scheduler = std::make_unique<TestAggregatableReportScheduler>(
&storage_context_,
base::BindLambdaForTesting(
[this](std::vector<AggregationServiceStorage::RequestAndId>
requests_and_ids) {
service_impl_->OnScheduledReportTimeReached(
std::move(requests_and_ids));
}));
test_scheduler_ = scheduler.get();
service_impl_ = AggregationServiceImpl::CreateForTesting(
/*run_in_memory=*/true, dir_.GetPath(),
task_environment_.GetMockClock(), std::move(scheduler),
std::move(assembler), std::move(sender));
}
void AssembleReport(AggregatableReportRequest request) {
service()->AssembleReport(
std::move(request), base::BindLambdaForTesting(
[&](absl::optional<AggregatableReport> report,
AggregationService::AssemblyStatus status) {
last_assembled_report_ = std::move(report);
last_assembly_status_ = status;
}));
}
void SendReport(const GURL& url, const AggregatableReport& report) {
service()->SendReport(
url, report,
base::BindLambdaForTesting([&](AggregationService::SendStatus status) {
last_send_status_ = status;
}));
}
void ScheduleReport(AggregatableReportRequest request) {
service()->ScheduleReport(std::move(request));
}
AggregationServiceImpl* service() { return service_impl_.get(); }
TestAggregatableReportAssembler* assembler() { return test_assembler_; }
TestAggregatableReportSender* sender() { return test_sender_; }
TestAggregatableReportScheduler* scheduler() { return test_scheduler_; }
// Returns `absl::nullopt` if no report callback has been run or if the last
// assembly had an error.
const absl::optional<AggregatableReport>& last_assembled_report() const {
return last_assembled_report_;
}
// Returns `absl::nullopt` if no report callback has been run.
const absl::optional<AggregationService::AssemblyStatus>&
last_assembly_status() const {
return last_assembly_status_;
}
// Returns `absl::nullopt` if no report callback has been run.
const absl::optional<AggregationService::SendStatus>& last_send_status()
const {
return last_send_status_;
}
private:
base::ScopedTempDir dir_;
BrowserTaskEnvironment task_environment_;
network::TestURLLoaderFactory test_url_loader_factory_;
std::unique_ptr<AggregationServiceImpl> service_impl_;
TestAggregationServiceStorageContext storage_context_;
raw_ptr<TestAggregatableReportAssembler> test_assembler_ = nullptr;
raw_ptr<TestAggregatableReportSender> test_sender_ = nullptr;
raw_ptr<TestAggregatableReportScheduler> test_scheduler_ = nullptr;
absl::optional<AggregatableReport> last_assembled_report_;
absl::optional<AggregationService::AssemblyStatus> last_assembly_status_;
absl::optional<AggregationService::SendStatus> last_send_status_;
};
// A successful assembler response must reach the caller's callback with a
// report and the `kOk` status.
TEST_F(AggregationServiceImplTest, AssembleReport_Succeed) {
  AssembleReport(aggregation_service::CreateExampleRequest());

  std::vector<AggregatableReport::AggregationServicePayload> report_payloads;
  report_payloads.emplace_back(/*payload=*/kABCD1234AsBytes,
                               /*key_id=*/"key_1",
                               /*debug_cleartext_payload=*/absl::nullopt);
  AggregatableReport assembled_report(std::move(report_payloads),
                                      "example_shared_info");

  // The test assembler numbers its pending requests from 0.
  assembler()->TriggerResponse(
      /*report_id=*/0, std::move(assembled_report),
      AggregatableReportAssembler::AssemblyStatus::kOk);

  EXPECT_TRUE(last_assembled_report().has_value());
  ASSERT_TRUE(last_assembly_status().has_value());
  EXPECT_EQ(last_assembly_status().value(),
            AggregationService::AssemblyStatus::kOk);
}
// A failed assembler response must reach the caller's callback with no report
// and the failure status unchanged.
TEST_F(AggregationServiceImplTest, AssembleReport_Fail) {
  AssembleReport(aggregation_service::CreateExampleRequest());

  // The test assembler numbers its pending requests from 0.
  assembler()->TriggerResponse(
      /*report_id=*/0, /*report=*/absl::nullopt,
      AggregatableReportAssembler::AssemblyStatus::kPublicKeyFetchFailed);

  EXPECT_FALSE(last_assembled_report().has_value());
  ASSERT_TRUE(last_assembly_status().has_value());
  EXPECT_EQ(last_assembly_status().value(),
            AggregationService::AssemblyStatus::kPublicKeyFetchFailed);
}
// A successful sender response must reach the caller's callback as
// `SendStatus::kOk`.
TEST_F(AggregationServiceImplTest, SendReport) {
  std::vector<AggregatableReport::AggregationServicePayload> report_payloads;
  report_payloads.emplace_back(/*payload=*/kABCD1234AsBytes,
                               /*key_id=*/"key_1",
                               /*debug_cleartext_payload=*/absl::nullopt);
  AggregatableReport report_to_send(std::move(report_payloads),
                                    "example_shared_info");

  const GURL reporting_url("https://example.com/reports");
  SendReport(reporting_url, report_to_send);

  // The test sender numbers its pending requests from 0.
  sender()->TriggerResponse(/*report_id=*/0,
                            AggregatableReportSender::RequestStatus::kOk);

  ASSERT_TRUE(last_send_status().has_value());
  EXPECT_EQ(last_send_status().value(), AggregationService::SendStatus::kOk);
}
// A scheduled request that is assembled and sent successfully must be reported
// back to the scheduler as succeeded.
TEST_F(AggregationServiceImplTest, ScheduleReport_Success) {
  ScheduleReport(aggregation_service::CreateExampleRequest());

  // Request IDs begin at 1.
  const AggregationServiceStorage::RequestId scheduled_id(1);
  scheduler()->TriggerReportingTime(/*request_ids=*/{scheduled_id});

  std::vector<AggregatableReport::AggregationServicePayload> report_payloads;
  report_payloads.emplace_back(/*payload=*/kABCD1234AsBytes,
                               /*key_id=*/"key_1",
                               /*debug_cleartext_payload=*/absl::nullopt);
  AggregatableReport assembled_report(std::move(report_payloads),
                                      "example_shared_info");

  // Assembler and sender both number their pending requests from 0.
  assembler()->TriggerResponse(
      /*report_id=*/0, std::move(assembled_report),
      AggregatableReportAssembler::AssemblyStatus::kOk);
  sender()->TriggerResponse(/*report_id=*/0,
                            AggregatableReportSender::RequestStatus::kOk);

  const absl::optional<bool> outcome =
      scheduler()->WasRequestSuccessful(scheduled_id);
  ASSERT_TRUE(outcome.has_value());
  EXPECT_TRUE(outcome.value());
}
// A scheduled request whose assembly fails must be reported back to the
// scheduler as failed. The sender is never triggered in this test.
TEST_F(AggregationServiceImplTest, ScheduleReport_FailedAssembly) {
  AggregatableReportRequest request =
      aggregation_service::CreateExampleRequest();
  ScheduleReport(std::move(request));

  // Request IDs begin at 1.
  scheduler()->TriggerReportingTime(
      /*request_ids=*/{AggregationServiceStorage::RequestId(1)});

  // No report is built here: a failed assembly carries `absl::nullopt`.
  assembler()->TriggerResponse(
      /*report_id=*/0, absl::nullopt,
      AggregatableReportAssembler::AssemblyStatus::kAssemblyFailed);

  ASSERT_TRUE(
      scheduler()
          ->WasRequestSuccessful(AggregationServiceStorage::RequestId(1))
          .has_value());
  EXPECT_FALSE(
      scheduler()
          ->WasRequestSuccessful(AggregationServiceStorage::RequestId(1))
          .value());
}
// A scheduled request that assembles successfully but fails to send must be
// reported back to the scheduler as failed.
TEST_F(AggregationServiceImplTest, ScheduleReport_FailedSending) {
  ScheduleReport(aggregation_service::CreateExampleRequest());

  // Request IDs begin at 1.
  const AggregationServiceStorage::RequestId scheduled_id(1);
  scheduler()->TriggerReportingTime(/*request_ids=*/{scheduled_id});

  std::vector<AggregatableReport::AggregationServicePayload> report_payloads;
  report_payloads.emplace_back(/*payload=*/kABCD1234AsBytes,
                               /*key_id=*/"key_1",
                               /*debug_cleartext_payload=*/absl::nullopt);
  AggregatableReport assembled_report(std::move(report_payloads),
                                      "example_shared_info");

  // Assembly succeeds, then the send fails with a network error.
  assembler()->TriggerResponse(
      /*report_id=*/0, std::move(assembled_report),
      AggregatableReportAssembler::AssemblyStatus::kOk);
  sender()->TriggerResponse(
      /*report_id=*/0, AggregatableReportSender::RequestStatus::kNetworkError);

  const absl::optional<bool> outcome =
      scheduler()->WasRequestSuccessful(scheduled_id);
  ASSERT_TRUE(outcome.has_value());
  EXPECT_FALSE(outcome.value());
}
// When the scheduler releases two requests at once, each must be assembled and
// sent, and each must be reported back to the scheduler as succeeded.
TEST_F(AggregationServiceImplTest,
       MultipleReportsReturnedFromScheduler_Success) {
  AggregatableReportRequest first_request =
      aggregation_service::CreateExampleRequest();
  AggregatableReportRequest second_request =
      aggregation_service::CreateExampleRequest();
  ScheduleReport(std::move(first_request));
  ScheduleReport(std::move(second_request));

  // Request IDs begin at 1.
  const AggregationServiceStorage::RequestId first_id(1);
  const AggregationServiceStorage::RequestId second_id(2);
  scheduler()->TriggerReportingTime(/*request_ids=*/{first_id, second_id});

  // Both reports are built from copies of the same payload list.
  std::vector<AggregatableReport::AggregationServicePayload> shared_payloads;
  shared_payloads.emplace_back(/*payload=*/kABCD1234AsBytes,
                               /*key_id=*/"key_1",
                               /*debug_cleartext_payload=*/absl::nullopt);
  AggregatableReport first_report(shared_payloads, "example_shared_info");
  AggregatableReport second_report(shared_payloads, "example_shared_info");

  // Assembler and sender both number their pending requests from 0.
  assembler()->TriggerResponse(
      /*report_id=*/0, std::move(first_report),
      AggregatableReportAssembler::AssemblyStatus::kOk);
  assembler()->TriggerResponse(
      /*report_id=*/1, std::move(second_report),
      AggregatableReportAssembler::AssemblyStatus::kOk);

  sender()->TriggerResponse(/*report_id=*/0,
                            AggregatableReportSender::RequestStatus::kOk);
  sender()->TriggerResponse(/*report_id=*/1,
                            AggregatableReportSender::RequestStatus::kOk);

  const absl::optional<bool> first_outcome =
      scheduler()->WasRequestSuccessful(first_id);
  ASSERT_TRUE(first_outcome.has_value());
  EXPECT_TRUE(first_outcome.value());

  const absl::optional<bool> second_outcome =
      scheduler()->WasRequestSuccessful(second_id);
  ASSERT_TRUE(second_outcome.has_value());
  EXPECT_TRUE(second_outcome.value());
}
} // namespace content