blob: 4c8125b33786964dc10c462c521ddd393def9ffe [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/aggregation_service/aggregation_service_impl.h"
#include <stdint.h>
#include <map>
#include <memory>
#include <utility>
#include <vector>
#include "base/containers/contains.h"
#include "base/files/scoped_temp_dir.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "content/browser/aggregation_service/aggregatable_report.h"
#include "content/browser/aggregation_service/aggregatable_report_assembler.h"
#include "content/browser/aggregation_service/aggregation_service_test_utils.h"
#include "content/public/test/browser_task_environment.h"
#include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h"
#include "services/network/test/test_url_loader_factory.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "url/gurl.h"
#include "url/origin.h"
namespace content {
class TestAggregatableReportAssembler : public AggregatableReportAssembler {
public:
TestAggregatableReportAssembler()
: AggregatableReportAssembler(
/*storage_context=*/nullptr,
base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>(
&test_url_loader_factory_)) {}
~TestAggregatableReportAssembler() override = default;
void AssembleReport(AggregatableReportRequest request,
AssemblyCallback callback) override {
callbacks_.emplace(unique_id_counter_++, std::move(callback));
}
void TriggerResponse(int64_t report_id,
absl::optional<AggregatableReport> report,
AssemblyStatus status) {
ASSERT_TRUE(base::Contains(callbacks_, report_id));
ASSERT_EQ(report.has_value(), status == AssemblyStatus::kOk);
std::move(callbacks_[report_id]).Run(std::move(report), status);
callbacks_.erase(report_id);
}
private:
int64_t unique_id_counter_ = 0;
std::map<int64_t, AssemblyCallback> callbacks_;
network::TestURLLoaderFactory test_url_loader_factory_;
};
// Fixture that wires an AggregationServiceImpl to a
// TestAggregatableReportAssembler and records the result of the most recent
// assembly callback.
class AggregationServiceImplTest : public testing::Test {
 public:
  AggregationServiceImplTest()
      : task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME) {
    // EXPECT rather than ASSERT: gtest fatal assertions cannot be used in a
    // constructor.
    EXPECT_TRUE(dir_.CreateUniqueTempDir());
    auto assembler = std::make_unique<TestAggregatableReportAssembler>();
    // Keep a non-owning pointer; ownership moves into `service_impl_` below.
    test_assembler_ = assembler.get();
    service_impl_ = AggregationServiceImpl::CreateForTesting(
        /*run_in_memory=*/true, dir_.GetPath(),
        task_environment_.GetMockClock(), std::move(assembler));
  }

  // Asks the service under test to assemble `request`. When the service runs
  // the callback, its arguments are stored and the run count is incremented;
  // read them back through the accessors below.
  void AssembleReport(AggregatableReportRequest request) {
    // The callback is handed to the service and may run after this function
    // returns, so capture only `this` explicitly instead of everything by
    // reference. The fixture owns `service_impl_` and therefore outlives it.
    service()->AssembleReport(
        std::move(request),
        base::BindLambdaForTesting(
            [this](absl::optional<AggregatableReport> report,
                   AggregationService::AssemblyStatus status) {
              last_assembled_report_ = std::move(report);
              last_assembly_status_ = status;
              ++num_assembly_callbacks_run_;
            }));
  }

  AggregationServiceImpl* service() { return service_impl_.get(); }

  TestAggregatableReportAssembler* assembler() { return test_assembler_; }

  // Number of times an assembly callback issued by AssembleReport() has run.
  int num_assembly_callbacks_run() const { return num_assembly_callbacks_run_; }

  // Should only be called after the report callback has been run.
  const absl::optional<AggregatableReport>& last_assembled_report() const {
    return last_assembled_report_;
  }

  // Should only be called after the report callback has been run.
  const absl::optional<AggregationService::AssemblyStatus>&
  last_assembly_status() const {
    return last_assembly_status_;
  }

 private:
  base::ScopedTempDir dir_;
  BrowserTaskEnvironment task_environment_;
  std::unique_ptr<AggregationServiceImpl> service_impl_;
  // Owned by `service_impl_`.
  raw_ptr<TestAggregatableReportAssembler> test_assembler_ = nullptr;
  int num_assembly_callbacks_run_ = 0;
  absl::optional<AggregatableReport> last_assembled_report_;
  absl::optional<AggregationService::AssemblyStatus> last_assembly_status_;
};
// A successful assembly is forwarded to the caller with the report attached.
TEST_F(AggregationServiceImplTest, AssembleReport_Succeed) {
  // First request on a fresh fixture, so the fake assembler files it as id 0.
  AssembleReport(aggregation_service::CreateExampleRequest());
  constexpr int64_t kReportId = 0;

  // Build the report the fake assembler will hand back.
  url::Origin origin_a = url::Origin::Create(GURL("https://a.example"));
  url::Origin origin_b = url::Origin::Create(GURL("https://b.example"));
  std::vector<AggregatableReport::AggregationServicePayload> payloads;
  payloads.emplace_back(origin_a, /*payload=*/kABCD1234AsBytes,
                        /*key_id=*/"key_1");
  payloads.emplace_back(origin_b, /*payload=*/kEFGH5678AsBytes,
                        /*key_id=*/"key_2");
  AggregatableReport report(
      std::move(payloads),
      AggregatableReportSharedInfo(base::Time::FromJavaTime(1234567890123),
                                   /*privacy_budget_key=*/"example_pbk"));

  assembler()->TriggerResponse(
      kReportId, std::move(report),
      AggregatableReportAssembler::AssemblyStatus::kOk);

  // Exactly one callback, carrying the report and a kOk status.
  EXPECT_EQ(num_assembly_callbacks_run(), 1);
  EXPECT_TRUE(last_assembled_report().has_value());
  ASSERT_TRUE(last_assembly_status().has_value());
  EXPECT_EQ(*last_assembly_status(), AggregationService::AssemblyStatus::kOk);
}
// A failed assembly is forwarded to the caller with no report and the
// assembler's failure status.
TEST_F(AggregationServiceImplTest, AssembleReport_Fail) {
  // First request on a fresh fixture, so the fake assembler files it as id 0.
  AssembleReport(aggregation_service::CreateExampleRequest());
  constexpr int64_t kReportId = 0;

  // Complete the pending request with a public-key-fetch failure.
  assembler()->TriggerResponse(
      kReportId, /*report=*/absl::nullopt,
      AggregatableReportAssembler::AssemblyStatus::kPublicKeyFetchFailed);

  // Exactly one callback, with no report and the matching failure status.
  EXPECT_EQ(num_assembly_callbacks_run(), 1);
  EXPECT_FALSE(last_assembled_report().has_value());
  ASSERT_TRUE(last_assembly_status().has_value());
  EXPECT_EQ(*last_assembly_status(),
            AggregationService::AssemblyStatus::kPublicKeyFetchFailed);
}
} // namespace content