blob: 42f658078ee893058427968953576baf2ff98f3d [file] [log] [blame]
// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/policy/messaging_layer/upload/record_handler_impl.h"
#include "base/base64.h"
#include "base/json/json_writer.h"
#include "base/optional.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/synchronization/waitable_event.h"
#include "base/task/post_task.h"
#include "base/task_runner.h"
#include "base/test/task_environment.h"
#include "base/values.h"
#include "chrome/browser/policy/messaging_layer/upload/dm_server_upload_service.h"
#include "chrome/browser/policy/messaging_layer/util/status.h"
#include "chrome/browser/policy/messaging_layer/util/status_macros.h"
#include "chrome/browser/policy/messaging_layer/util/statusor.h"
#include "chrome/browser/policy/messaging_layer/util/task_runner_context.h"
#include "components/policy/core/common/cloud/cloud_policy_client.h"
#include "components/policy/core/common/cloud/dm_token.h"
#include "components/policy/core/common/cloud/mock_cloud_policy_client.h"
#include "components/policy/proto/record.pb.h"
#include "components/policy/proto/record_constants.pb.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
using ::testing::_;
using ::testing::Eq;
using ::testing::Invoke;
using ::testing::MockFunction;
using ::testing::Property;
using ::testing::Return;
using ::testing::StrictMock;
using ::testing::WithArgs;
namespace reporting {
namespace {
MATCHER_P(ValueEqualsProto,
expected,
"Compares StatusOr<MessageLite> to expected MessageLite") {
if (!arg.ok()) {
return false;
}
if (arg.ValueOrDie().GetTypeName() != expected.GetTypeName()) {
return false;
}
return arg.ValueOrDie().SerializeAsString() == expected.SerializeAsString();
}
class TestCallbackWaiter {
public:
TestCallbackWaiter() = default;
virtual void Signal() { run_loop_.Quit(); }
void Wait() { run_loop_.Run(); }
protected:
base::RunLoop run_loop_;
};
class TestCallbackWaiterWithCounter : public TestCallbackWaiter {
public:
explicit TestCallbackWaiterWithCounter(size_t counter_limit)
: counter_limit_(counter_limit) {}
void Signal() override {
DCHECK_GT(counter_limit_, 0u);
if (--counter_limit_ == 0u) {
run_loop_.Quit();
}
}
private:
std::atomic<size_t> counter_limit_;
};
using TestCompletionResponder =
MockFunction<void(DmServerUploadService::CompletionResponse)>;
using TestEncryptionKeyAttached = MockFunction<void(SignedEncryptionInfo)>;
// Helper function for retrieving and processing the SequencingInformation from
// a request.
void RetrieveFinalSequencingInformation(const base::Value& request,
base::Value& sequencing_info) {
ASSERT_TRUE(request.is_dict());
// Retrieve and process sequencing information
const base::Value* const encrypted_record_list =
request.FindListKey("encryptedRecord");
ASSERT_TRUE(encrypted_record_list != nullptr);
ASSERT_FALSE(encrypted_record_list->GetList().empty());
const auto* seq_info = encrypted_record_list->GetList().rbegin()->FindDictKey(
"sequencingInformation");
ASSERT_TRUE(seq_info != nullptr);
ASSERT_TRUE(!seq_info->FindStringKey("sequencingId")->empty());
ASSERT_TRUE(!seq_info->FindStringKey("generationId")->empty());
ASSERT_TRUE(seq_info->FindIntKey("priority"));
sequencing_info.MergeDictionary(seq_info);
// Set half of sequencing information to return a string instead of an int for
// priority.
int64_t sequencing_id;
ASSERT_TRUE(base::StringToInt64(
*sequencing_info.FindStringKey("sequencingId"), &sequencing_id));
if (sequencing_id % 2) {
const auto int_result = sequencing_info.FindIntKey("priority");
ASSERT_TRUE(int_result.has_value());
sequencing_info.RemoveKey("priority");
sequencing_info.SetStringKey("priority", Priority_Name(int_result.value()));
}
}
base::Optional<base::Value> BuildEncryptionSettingsFromRequest(
const base::Value& request) {
// If attach_encryption_settings it true, process that.
const auto attach_encryption_settings =
request.FindBoolKey("attachEncryptionSettings");
if (!attach_encryption_settings.has_value() ||
!attach_encryption_settings.value()) {
return base::nullopt;
}
base::Value encryption_settings{base::Value::Type::DICTIONARY};
std::string public_key;
base::Base64Encode("PUBLIC KEY", &public_key);
encryption_settings.SetStringKey("publicKey", public_key);
encryption_settings.SetIntKey("publicKeyId", 12345);
std::string public_key_signature;
// TODO(b/170054326): Generate signature.
base::Base64Encode("PUBLIC KEY SIG", &public_key_signature);
encryption_settings.SetStringKey("publicKeySignature", public_key_signature);
return encryption_settings;
}
// Immitates the server response for successful record upload. Since additional
// steps and tests require the response from the server to be accurate, ASSERTS
// that the |request| must be valid, and on a valid request updates |response|.
void SucceedResponseFromRequest(const base::Value& request,
base::Value& response) {
base::Value seq_info{base::Value::Type::DICTIONARY};
RetrieveFinalSequencingInformation(request, seq_info);
response.SetPath("lastSucceedUploadedRecord", std::move(seq_info));
// If attach_encryption_settings it true, process that.
auto encryption_settings_result = BuildEncryptionSettingsFromRequest(request);
if (encryption_settings_result.has_value()) {
response.SetPath("encryptionSettings",
std::move(encryption_settings_result.value()));
}
}
// Immitates the server response for failed record upload. Since additional
// steps and tests require the response from the server to be accurate, ASSERTS
// that the |request| must be valid, and on a valid request updates |response|.
void FailedResponseFromRequest(const base::Value& request,
base::Value& response) {
base::Value seq_info{base::Value::Type::DICTIONARY};
RetrieveFinalSequencingInformation(request, seq_info);
response.SetPath("lastSucceedUploadedRecord", seq_info.Clone());
// The lastSucceedUploadedRecord should be the record before the one indicated
// in seq_info. |seq_info| has been built by
// RetrieveFinalSequencingInforamation and is guaranteed to have this key.
int64_t sequencing_id;
ASSERT_TRUE(base::StringToInt64(*seq_info.FindStringKey("sequencingId"),
&sequencing_id));
response.SetStringPath("lastSucceedUploadedRecord.sequencingId",
base::NumberToString(sequencing_id - 1));
// The firstFailedUploadedRecord.failedUploadedRecord should be the one
// indicated in seq_info.
response.SetPath("firstFailedUploadedRecord.failedUploadedRecord",
std::move(seq_info));
auto encryption_settings_result = BuildEncryptionSettingsFromRequest(request);
if (encryption_settings_result.has_value()) {
response.SetPath("encryptionSettings",
std::move(encryption_settings_result.value()));
}
}
class RecordHandlerImplTest : public ::testing::TestWithParam<bool> {
public:
RecordHandlerImplTest()
: client_(std::make_unique<policy::MockCloudPolicyClient>()) {}
protected:
void SetUp() override {
client_->SetDMToken(
policy::DMToken::CreateValidTokenForTesting("FAKE_DM_TOKEN").value());
}
bool need_encryption_key() const { return GetParam(); }
content::BrowserTaskEnvironment task_environment_;
std::unique_ptr<policy::MockCloudPolicyClient> client_;
};
std::unique_ptr<std::vector<EncryptedRecord>> BuildTestRecordsVector(
size_t number_of_test_records,
int64_t generation_id) {
std::unique_ptr<std::vector<EncryptedRecord>> test_records =
std::make_unique<std::vector<EncryptedRecord>>();
test_records->reserve(number_of_test_records);
for (size_t i = 0; i < number_of_test_records; i++) {
EncryptedRecord encrypted_record;
encrypted_record.set_encrypted_wrapped_record(
base::StrCat({"Record Number ", base::NumberToString(i)}));
auto* sequencing_information =
encrypted_record.mutable_sequencing_information();
sequencing_information->set_generation_id(generation_id);
sequencing_information->set_sequencing_id(i);
sequencing_information->set_priority(Priority::IMMEDIATE);
test_records->push_back(std::move(encrypted_record));
}
return test_records;
}
TEST_P(RecordHandlerImplTest, ForwardsRecordsToCloudPolicyClient) {
constexpr size_t kNumTestRecords = 10;
constexpr int64_t kGenerationId = 1234;
auto test_records = BuildTestRecordsVector(kNumTestRecords, kGenerationId);
TestCallbackWaiterWithCounter client_waiter{kNumTestRecords};
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.Times(kNumTestRecords)
.WillRepeatedly(WithArgs<0, 2>(
Invoke([&client_waiter](
base::Value request,
policy::CloudPolicyClient::ResponseCallback callback) {
base::Value response{base::Value::Type::DICTIONARY};
SucceedResponseFromRequest(request, response);
std::move(callback).Run(std::move(response));
client_waiter.Signal();
})));
RecordHandlerImpl handler(client_.get());
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
StrictMock<TestCompletionResponder> responder;
TestCallbackWaiter responder_waiter;
EXPECT_CALL(encryption_key_attached, Call(_))
.Times(need_encryption_key() ? 1 : 0);
EXPECT_CALL(
responder,
Call(ValueEqualsProto(test_records->back().sequencing_information())))
.WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); }));
auto encryption_key_attached_callback =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
auto responder_callback = base::BindOnce(&TestCompletionResponder::Call,
base::Unretained(&responder));
handler.HandleRecords(need_encryption_key(), std::move(test_records),
std::move(responder_callback),
encryption_key_attached_callback);
client_waiter.Wait();
responder_waiter.Wait();
}
TEST_P(RecordHandlerImplTest, ReportsEarlyFailure) {
const int64_t kNumSuccessfulUploads = 5;
const int64_t kNumTestRecords = 10;
const int64_t kGenerationId = 1234;
auto test_records = BuildTestRecordsVector(kNumTestRecords, kGenerationId);
// Wait kNumSuccessfulUploads times + 1 for the failure.
TestCallbackWaiterWithCounter client_waiter{kNumSuccessfulUploads + 1};
{
::testing::InSequence seq;
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.Times(kNumSuccessfulUploads)
.WillRepeatedly(WithArgs<0, 2>(
Invoke([&client_waiter](
base::Value request,
policy::CloudPolicyClient::ResponseCallback callback) {
base::Value response{base::Value::Type::DICTIONARY};
SucceedResponseFromRequest(request, response);
std::move(callback).Run(std::move(response));
client_waiter.Signal();
})));
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.WillOnce(WithArgs<2>(Invoke(
[&client_waiter](base::OnceCallback<void(
base::Optional<base::Value>)> callback) {
std::move(callback).Run(base::nullopt);
client_waiter.Signal();
})));
}
RecordHandlerImpl handler(client_.get());
StrictMock<TestCompletionResponder> responder;
TestCallbackWaiter responder_waiter;
EXPECT_CALL(
responder,
Call(ValueEqualsProto(
(*test_records)[kNumSuccessfulUploads - 1].sequencing_information())))
.WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); }));
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_))
.Times(need_encryption_key() ? 1 : 0);
auto encryption_key_attached_callback =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
auto responder_callback = base::BindOnce(&TestCompletionResponder::Call,
base::Unretained(&responder));
handler.HandleRecords(need_encryption_key(), std::move(test_records),
std::move(responder_callback),
encryption_key_attached_callback);
client_waiter.Wait();
responder_waiter.Wait();
}
TEST_P(RecordHandlerImplTest, UploadsGapRecordOnServerFailure) {
const int64_t kNumInitialSuccessfulUploads = 5;
const int64_t kNumTestRecords = 10;
const int64_t kNumFinalSuccessfulUploads =
kNumTestRecords - kNumInitialSuccessfulUploads;
const int64_t kGenerationId = 1234;
auto test_records = BuildTestRecordsVector(kNumTestRecords, kGenerationId);
// Wait kNumTestRecords times + 1 for the failure.
TestCallbackWaiterWithCounter client_waiter{kNumTestRecords + 1};
{
::testing::InSequence seq;
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.Times(kNumInitialSuccessfulUploads)
.WillRepeatedly(WithArgs<0, 2>(
Invoke([&client_waiter](
base::Value request,
policy::CloudPolicyClient::ResponseCallback callback) {
base::Value response{base::Value::Type::DICTIONARY};
SucceedResponseFromRequest(request, response);
std::move(callback).Run(std::move(response));
client_waiter.Signal();
})));
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.WillOnce(WithArgs<0, 2>(
Invoke([&client_waiter](
base::Value request,
policy::CloudPolicyClient::ResponseCallback callback) {
base::Value response{base::Value::Type::DICTIONARY};
FailedResponseFromRequest(request, response);
std::move(callback).Run(std::move(response));
client_waiter.Signal();
})));
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.Times(kNumFinalSuccessfulUploads)
.WillRepeatedly(WithArgs<0, 2>(
Invoke([&client_waiter](
base::Value request,
policy::CloudPolicyClient::ResponseCallback callback) {
base::Value response{base::Value::Type::DICTIONARY};
SucceedResponseFromRequest(request, response);
std::move(callback).Run(std::move(response));
client_waiter.Signal();
})));
}
RecordHandlerImpl handler(client_.get());
StrictMock<TestCallbackWaiter> responder_waiter;
TestCompletionResponder responder;
EXPECT_CALL(
responder,
Call(ValueEqualsProto(
(*test_records)[kNumTestRecords - 1].sequencing_information())))
.WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); }));
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
EXPECT_CALL(encryption_key_attached, Call(_))
.Times(need_encryption_key() ? 1 : 0);
auto encryption_key_attached_callback =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
auto responder_callback = base::BindOnce(&TestCompletionResponder::Call,
base::Unretained(&responder));
handler.HandleRecords(need_encryption_key(), std::move(test_records),
std::move(responder_callback),
encryption_key_attached_callback);
client_waiter.Wait();
responder_waiter.Wait();
}
// There may be cases where the server and the client do not align in the
// expected response, clients shouldn't crash in these instances, but simply
// report an internal error.
TEST_P(RecordHandlerImplTest, HandleUnknownResponseFromServer) {
constexpr size_t kNumTestRecords = 10;
constexpr int64_t kGenerationId = 1234;
auto test_records = BuildTestRecordsVector(kNumTestRecords, kGenerationId);
TestCallbackWaiterWithCounter client_waiter{kNumTestRecords};
EXPECT_CALL(*client_, UploadEncryptedReport(_, _, _))
.Times(kNumTestRecords)
.WillRepeatedly(WithArgs<2>(
Invoke([&client_waiter](
policy::CloudPolicyClient::ResponseCallback callback) {
std::move(callback).Run(base::Value{base::Value::Type::DICTIONARY});
client_waiter.Signal();
})));
RecordHandlerImpl handler(client_.get());
StrictMock<TestEncryptionKeyAttached> encryption_key_attached;
StrictMock<TestCompletionResponder> responder;
TestCallbackWaiter responder_waiter;
EXPECT_CALL(encryption_key_attached, Call(_)).Times(0);
EXPECT_CALL(
responder,
Call(Property(&StatusOr<SequencingInformation>::status,
Property(&Status::error_code, Eq(error::INTERNAL)))))
.WillOnce(Invoke([&responder_waiter]() { responder_waiter.Signal(); }));
auto encryption_key_attached_callback =
base::BindRepeating(&TestEncryptionKeyAttached::Call,
base::Unretained(&encryption_key_attached));
auto responder_callback = base::BindOnce(&TestCompletionResponder::Call,
base::Unretained(&responder));
handler.HandleRecords(need_encryption_key(), std::move(test_records),
std::move(responder_callback),
encryption_key_attached_callback);
client_waiter.Wait();
responder_waiter.Wait();
}
INSTANTIATE_TEST_SUITE_P(NeedOrNoNeedKey,
RecordHandlerImplTest,
testing::Bool());
} // namespace
} // namespace reporting