blob: c69f2cec5c44fe136f992d799543b3e4adbce44f [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/aggregation_service/aggregation_service_impl.h"
#include <stdint.h>
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "base/containers/contains.h"
#include "base/files/scoped_temp_dir.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "content/browser/aggregation_service/aggregatable_report.h"
#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
#include "content/browser/aggregation_service/aggregation_service_test_utils.h"
#include "content/public/test/browser_task_environment.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"
namespace content {
class TestAggregatableReportAssembler : public AggregatableReportAssembler {
public:
explicit TestAggregatableReportAssembler(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
: AggregatableReportAssembler(
/*storage_context=*/nullptr,
std::move(url_loader_factory)) {}
~TestAggregatableReportAssembler() override = default;
void AssembleReport(AggregatableReportRequest request,
AssemblyCallback callback) override {
callbacks_.emplace(unique_id_counter_++, std::move(callback));
}
void TriggerResponse(int64_t report_id,
absl::optional<AggregatableReport> report,
AssemblyStatus status) {
ASSERT_TRUE(base::Contains(callbacks_, report_id));
ASSERT_EQ(report.has_value(), status == AssemblyStatus::kOk);
std::move(callbacks_[report_id]).Run(std::move(report), status);
callbacks_.erase(report_id);
}
private:
int64_t unique_id_counter_ = 0;
std::map<int64_t, AssemblyCallback> callbacks_;
};
class TestAggregatableReportSender : public AggregatableReportSender {
public:
explicit TestAggregatableReportSender(
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory)
: AggregatableReportSender(std::move(url_loader_factory)) {}
~TestAggregatableReportSender() override = default;
void SendReport(const GURL& url,
const base::Value& contents,
ReportSentCallback callback) override {
callbacks_.emplace(unique_id_counter_++, std::move(callback));
}
void TriggerResponse(int64_t report_id, RequestStatus status) {
ASSERT_TRUE(base::Contains(callbacks_, report_id));
std::move(callbacks_[report_id]).Run(status);
callbacks_.erase(report_id);
}
private:
int64_t unique_id_counter_ = 0;
std::map<int64_t, ReportSentCallback> callbacks_;
};
class AggregationServiceImplTest : public testing::Test {
public:
AggregationServiceImplTest()
: task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME) {
EXPECT_TRUE(dir_.CreateUniqueTempDir());
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory =
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
&test_url_loader_factory_);
auto assembler =
std::make_unique<TestAggregatableReportAssembler>(url_loader_factory);
test_assembler_ = assembler.get();
auto sender =
std::make_unique<TestAggregatableReportSender>(url_loader_factory);
test_sender_ = sender.get();
service_impl_ = AggregationServiceImpl::CreateForTesting(
/*run_in_memory=*/true, dir_.GetPath(),
task_environment_.GetMockClock(), std::move(assembler),
std::move(sender));
}
void AssembleReport(AggregatableReportRequest request) {
service()->AssembleReport(
std::move(request), base::BindLambdaForTesting(
[&](absl::optional<AggregatableReport> report,
AggregationService::AssemblyStatus status) {
last_assembled_report_ = std::move(report);
last_assembly_status_ = status;
}));
}
void SendReport(const GURL& url, const AggregatableReport& report) {
service()->SendReport(
url, report,
base::BindLambdaForTesting([&](AggregationService::SendStatus status) {
last_send_status_ = status;
}));
}
AggregationServiceImpl* service() { return service_impl_.get(); }
TestAggregatableReportAssembler* assembler() { return test_assembler_; }
TestAggregatableReportSender* sender() { return test_sender_; }
// Returns `absl::nullopt` if no report callback has been run or if the last
// assembly had an error.
const absl::optional<AggregatableReport>& last_assembled_report() const {
return last_assembled_report_;
}
// Returns `absl::nullopt` if no report callback has been run.
const absl::optional<AggregationService::AssemblyStatus>&
last_assembly_status() const {
return last_assembly_status_;
}
// Returns `absl::nullopt` if no report callback has been run.
const absl::optional<AggregationService::SendStatus>& last_send_status()
const {
return last_send_status_;
}
private:
base::ScopedTempDir dir_;
BrowserTaskEnvironment task_environment_;
network::TestURLLoaderFactory test_url_loader_factory_;
std::unique_ptr<AggregationServiceImpl> service_impl_;
raw_ptr<TestAggregatableReportAssembler> test_assembler_ = nullptr;
raw_ptr<TestAggregatableReportSender> test_sender_ = nullptr;
absl::optional<AggregatableReport> last_assembled_report_;
absl::optional<AggregationService::AssemblyStatus> last_assembly_status_;
absl::optional<AggregationService::SendStatus> last_send_status_;
};
// Verifies two things: the service's assembly callback stays pending until the
// assembler responds, and a successful assembly is delivered with `kOk`.
TEST_F(AggregationServiceImplTest, AssembleReport_Succeed) {
  AggregatableReportRequest request =
      aggregation_service::CreateExampleRequest();
  AssembleReport(std::move(request));

  // The test assembler only stores the callback, so nothing has completed yet.
  EXPECT_FALSE(last_assembly_status().has_value());
  EXPECT_FALSE(last_assembled_report().has_value());

  std::vector<AggregatableReport::AggregationServicePayload> payloads;
  payloads.reserve(2);
  payloads.emplace_back(/*payload=*/kABCD1234AsBytes,
                        /*key_id=*/"key_1",
                        /*debug_cleartext_payload=*/absl::nullopt);
  payloads.emplace_back(/*payload=*/kEFGH5678AsBytes,
                        /*key_id=*/"key_2",
                        /*debug_cleartext_payload=*/absl::nullopt);
  AggregatableReport report(std::move(payloads), "example_shared_info");

  // The first request handed to the assembler gets id 0.
  assembler()->TriggerResponse(
      /*report_id=*/0, std::move(report),
      AggregatableReportAssembler::AssemblyStatus::kOk);

  EXPECT_TRUE(last_assembled_report().has_value());
  ASSERT_TRUE(last_assembly_status().has_value());
  EXPECT_EQ(last_assembly_status().value(),
            AggregationService::AssemblyStatus::kOk);
}
// Verifies two things: the service's assembly callback stays pending until the
// assembler responds, and an assembler failure is delivered as an empty report
// with the matching `AggregationService` status.
TEST_F(AggregationServiceImplTest, AssembleReport_Fail) {
  AggregatableReportRequest request =
      aggregation_service::CreateExampleRequest();
  AssembleReport(std::move(request));

  // The test assembler only stores the callback, so nothing has completed yet.
  EXPECT_FALSE(last_assembly_status().has_value());

  // The first request handed to the assembler gets id 0.
  assembler()->TriggerResponse(
      /*report_id=*/0, absl::nullopt,
      AggregatableReportAssembler::AssemblyStatus::kPublicKeyFetchFailed);

  EXPECT_FALSE(last_assembled_report().has_value());
  ASSERT_TRUE(last_assembly_status().has_value());
  EXPECT_EQ(last_assembly_status().value(),
            AggregationService::AssemblyStatus::kPublicKeyFetchFailed);
}
// Verifies two things: the service's send callback stays pending until the
// sender responds, and a successful send is reported as `SendStatus::kOk`.
TEST_F(AggregationServiceImplTest, SendReport) {
  std::vector<AggregatableReport::AggregationServicePayload> payloads;
  payloads.reserve(2);
  payloads.emplace_back(/*payload=*/kABCD1234AsBytes,
                        /*key_id=*/"key_1",
                        /*debug_cleartext_payload=*/absl::nullopt);
  payloads.emplace_back(/*payload=*/kEFGH5678AsBytes,
                        /*key_id=*/"key_2",
                        /*debug_cleartext_payload=*/absl::nullopt);
  AggregatableReport report(std::move(payloads), "example_shared_info");

  SendReport(GURL("https://example.com/reports"), report);

  // The test sender only stores the callback, so nothing has completed yet.
  EXPECT_FALSE(last_send_status().has_value());

  // The first request handed to the sender gets id 0.
  sender()->TriggerResponse(/*report_id=*/0,
                            AggregatableReportSender::RequestStatus::kOk);

  ASSERT_TRUE(last_send_status().has_value());
  EXPECT_EQ(last_send_status().value(), AggregationService::SendStatus::kOk);
}
} // namespace content