| // Copyright 2021 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "content/browser/aggregation_service/aggregation_service_impl.h" |
| |
| #include <stdint.h> |
| |
| #include <map> |
| #include <memory> |
| #include <utility> |
| #include <vector> |
| |
| #include "base/containers/contains.h" |
| #include "base/files/scoped_temp_dir.h" |
| #include "base/memory/raw_ptr.h" |
| #include "base/memory/scoped_refptr.h" |
| #include "base/test/bind.h" |
| #include "base/time/time.h" |
| #include "content/browser/aggregation_service/aggregatable_report.h" |
| #include "content/browser/aggregation_service/aggregatable_report_assembler.h" |
| #include "content/browser/aggregation_service/aggregation_service_test_utils.h" |
| #include "content/public/test/browser_task_environment.h" |
| #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" |
| #include "services/network/test/test_url_loader_factory.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| #include "third_party/abseil-cpp/absl/types/optional.h" |
| #include "url/gurl.h" |
| #include "url/origin.h" |
| |
| namespace content { |
| |
| class TestAggregatableReportAssembler : public AggregatableReportAssembler { |
| public: |
| TestAggregatableReportAssembler() |
| : AggregatableReportAssembler( |
| /*storage_context=*/nullptr, |
| base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>( |
| &test_url_loader_factory_)) {} |
| ~TestAggregatableReportAssembler() override = default; |
| |
| void AssembleReport(AggregatableReportRequest request, |
| AssemblyCallback callback) override { |
| callbacks_.emplace(unique_id_counter_++, std::move(callback)); |
| } |
| |
| void TriggerResponse(int64_t report_id, |
| absl::optional<AggregatableReport> report, |
| AssemblyStatus status) { |
| ASSERT_TRUE(base::Contains(callbacks_, report_id)); |
| ASSERT_EQ(report.has_value(), status == AssemblyStatus::kOk); |
| |
| std::move(callbacks_[report_id]).Run(std::move(report), status); |
| callbacks_.erase(report_id); |
| } |
| |
| private: |
| int64_t unique_id_counter_ = 0; |
| std::map<int64_t, AssemblyCallback> callbacks_; |
| |
| network::TestURLLoaderFactory test_url_loader_factory_; |
| }; |
| |
| class AggregationServiceImplTest : public testing::Test { |
| public: |
| AggregationServiceImplTest() |
| : task_environment_(base::test::TaskEnvironment::TimeSource::MOCK_TIME) { |
| EXPECT_TRUE(dir_.CreateUniqueTempDir()); |
| |
| auto assembler = std::make_unique<TestAggregatableReportAssembler>(); |
| test_assembler_ = assembler.get(); |
| service_impl_ = AggregationServiceImpl::CreateForTesting( |
| /*run_in_memory=*/true, dir_.GetPath(), |
| task_environment_.GetMockClock(), std::move(assembler)); |
| } |
| |
| void AssembleReport(AggregatableReportRequest request) { |
| service()->AssembleReport( |
| std::move(request), base::BindLambdaForTesting( |
| [&](absl::optional<AggregatableReport> report, |
| AggregationService::AssemblyStatus status) { |
| last_assembled_report_ = std::move(report); |
| last_assembly_status_ = status; |
| ++num_assembly_callbacks_run_; |
| })); |
| } |
| |
| AggregationServiceImpl* service() { return service_impl_.get(); } |
| TestAggregatableReportAssembler* assembler() { return test_assembler_; } |
| |
| int num_assembly_callbacks_run() const { return num_assembly_callbacks_run_; } |
| |
| // Should only be called after the report callback has been run. |
| const absl::optional<AggregatableReport>& last_assembled_report() const { |
| return last_assembled_report_; |
| } |
| |
| // Should only be called after the report callback has been run. |
| const absl::optional<AggregationService::AssemblyStatus>& |
| last_assembly_status() const { |
| return last_assembly_status_; |
| } |
| |
| private: |
| base::ScopedTempDir dir_; |
| BrowserTaskEnvironment task_environment_; |
| std::unique_ptr<AggregationServiceImpl> service_impl_; |
| raw_ptr<TestAggregatableReportAssembler> test_assembler_ = nullptr; |
| |
| int num_assembly_callbacks_run_ = 0; |
| absl::optional<AggregatableReport> last_assembled_report_; |
| absl::optional<AggregationService::AssemblyStatus> last_assembly_status_; |
| }; |
| |
| TEST_F(AggregationServiceImplTest, AssembleReport_Succeed) { |
| AggregatableReportRequest request = |
| aggregation_service::CreateExampleRequest(); |
| |
| AssembleReport(std::move(request)); |
| |
| std::vector<AggregatableReport::AggregationServicePayload> payloads; |
| payloads.emplace_back(url::Origin::Create(GURL("https://a.example")), |
| /*payload=*/kABCD1234AsBytes, |
| /*key_id=*/"key_1"); |
| payloads.emplace_back(url::Origin::Create(GURL("https://b.example")), |
| /*payload=*/kEFGH5678AsBytes, |
| /*key_id=*/"key_2"); |
| |
| AggregatableReportSharedInfo shared_info( |
| base::Time::FromJavaTime(1234567890123), |
| /*privacy_budget_key=*/"example_pbk"); |
| |
| AggregatableReport report(std::move(payloads), std::move(shared_info)); |
| assembler()->TriggerResponse( |
| /*report_id=*/0, std::move(report), |
| AggregatableReportAssembler::AssemblyStatus::kOk); |
| |
| EXPECT_EQ(num_assembly_callbacks_run(), 1); |
| EXPECT_TRUE(last_assembled_report().has_value()); |
| ASSERT_TRUE(last_assembly_status().has_value()); |
| EXPECT_EQ(last_assembly_status().value(), |
| AggregationService::AssemblyStatus::kOk); |
| } |
| |
| TEST_F(AggregationServiceImplTest, AssembleReport_Fail) { |
| AggregatableReportRequest request = |
| aggregation_service::CreateExampleRequest(); |
| |
| AssembleReport(std::move(request)); |
| |
| assembler()->TriggerResponse( |
| /*report_id=*/0, absl::nullopt, |
| AggregatableReportAssembler::AssemblyStatus::kPublicKeyFetchFailed); |
| |
| EXPECT_EQ(num_assembly_callbacks_run(), 1); |
| EXPECT_FALSE(last_assembled_report().has_value()); |
| ASSERT_TRUE(last_assembly_status().has_value()); |
| EXPECT_EQ(last_assembly_status().value(), |
| AggregationService::AssemblyStatus::kPublicKeyFetchFailed); |
| } |
| |
| } // namespace content |