blob: 5fd2615d56ddd15e6a0215f5711c81431c0a1786 [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/policy/test_support/policy_storage.h"
#include "base/big_endian.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "crypto/sha2.h"
namespace policy {
namespace {

// Placed between the policy type and the entity id when both are present in a
// payload key.
constexpr char kPolicyKeySeparator[] = "/";

// Builds the key under which a payload is stored in the payload map.
// Yields |policy_type| unchanged when |entity_id| is empty; otherwise the
// result is "<policy_type>/<entity_id>".
std::string GetPolicyKey(const std::string& policy_type,
                         const std::string& entity_id) {
  return entity_id.empty()
             ? policy_type
             : base::StrCat({policy_type, kPolicyKeySeparator, entity_id});
}

}  // namespace
// Default constructor: creates the storage with its own newly allocated
// SignatureProvider.
PolicyStorage::PolicyStorage()
: signature_provider_(std::make_unique<SignatureProvider>()) {}
// Move constructor; relies on the compiler-generated member-wise move.
PolicyStorage::PolicyStorage(PolicyStorage&& policy_storage) = default;
// Move assignment; relies on the compiler-generated member-wise move.
PolicyStorage& PolicyStorage::operator=(PolicyStorage&& policy_storage) =
default;
// Destructor; defined out of line here and defaulted.
PolicyStorage::~PolicyStorage() = default;
// Looks up the payload stored for |policy_type| and |entity_id|.
// Returns a copy of the stored payload, or an empty string when nothing has
// been stored under that key.
std::string PolicyStorage::GetPolicyPayload(
    const std::string& policy_type,
    const std::string& entity_id) const {
  const std::string key = GetPolicyKey(policy_type, entity_id);
  const auto entry = policy_payloads_.find(key);
  if (entry == policy_payloads_.end())
    return std::string();
  return entry->second;
}
// Collects the entity ids that have a payload stored for |policy_type|.
// Every stored key beginning with "<policy_type>/" contributes the part of
// the key that follows that prefix. A payload stored under the bare policy
// type (no entity id) does not match the prefix and is therefore not listed.
std::vector<std::string> PolicyStorage::GetEntityIdsForType(
    const std::string& policy_type) {
  const std::string key_prefix =
      base::StrCat({policy_type, kPolicyKeySeparator});
  const size_t id_offset = key_prefix.length();

  std::vector<std::string> entity_ids;
  for (const auto& entry : policy_payloads_) {
    const std::string& key = entry.first;
    if (!base::StartsWith(key, key_prefix))
      continue;
    entity_ids.push_back(key.substr(id_offset));
  }
  return entity_ids;
}
// Convenience overload: stores |policy_payload| for |policy_type| with an
// empty entity id by forwarding to the three-argument overload.
void PolicyStorage::SetPolicyPayload(const std::string& policy_type,
                                     const std::string& policy_payload) {
  const std::string no_entity_id;
  SetPolicyPayload(policy_type, no_entity_id, policy_payload);
}
// Stores |policy_payload| under the key derived from |policy_type| and
// |entity_id|, replacing any payload previously stored under that key.
void PolicyStorage::SetPolicyPayload(const std::string& policy_type,
                                     const std::string& entity_id,
                                     const std::string& policy_payload) {
  const std::string key = GetPolicyKey(policy_type, entity_id);
  policy_payloads_[key] = policy_payload;
}
// Records |psm_entry| for |brand_serial_id|, overwriting any entry already
// stored for that id.
void PolicyStorage::SetPsmEntry(const std::string& brand_serial_id,
                                const PolicyStorage::PsmEntry& psm_entry) {
  auto& slot = psm_entries_[brand_serial_id];
  slot = psm_entry;
}
// Returns a pointer to the PSM entry stored for |brand_serial_id|, or nullptr
// when no entry exists. The pointer refers to the element held inside
// |psm_entries_|.
const PolicyStorage::PsmEntry* PolicyStorage::GetPsmEntry(
    const std::string& brand_serial_id) const {
  const auto found = psm_entries_.find(brand_serial_id);
  return found == psm_entries_.end() ? nullptr : &found->second;
}
// Records |initial_enrollment_state| for |brand_serial_id|, overwriting any
// state already stored for that id.
void PolicyStorage::SetInitialEnrollmentState(
    const std::string& brand_serial_id,
    const PolicyStorage::InitialEnrollmentState& initial_enrollment_state) {
  auto& slot = initial_enrollment_states_[brand_serial_id];
  slot = initial_enrollment_state;
}
// Returns a pointer to the initial enrollment state stored for
// |brand_serial_id|, or nullptr when none exists. The pointer refers to the
// element held inside |initial_enrollment_states_|.
const PolicyStorage::InitialEnrollmentState*
PolicyStorage::GetInitialEnrollmentState(
    const std::string& brand_serial_id) const {
  const auto found = initial_enrollment_states_.find(brand_serial_id);
  return found == initial_enrollment_states_.end() ? nullptr : &found->second;
}
// Returns the truncated hashes of all ids that have an initial enrollment
// state and whose hash satisfies `hash % modulus == remainder`.
//
// For each key of |initial_enrollment_states_|, the first sizeof(uint64_t)
// bytes written by crypto::SHA256HashString are read as a big-endian integer
// and tested against the requested residue class. Each match is returned as
// the raw hash bytes (an 8-byte binary string, not a textual encoding).
//
// A |modulus| of zero yields an empty result instead of dividing by zero.
std::vector<std::string> PolicyStorage::GetMatchingSerialHashes(
    uint64_t modulus,
    uint64_t remainder) const {
  std::vector<std::string> hashes;
  // `hash % 0` is undefined behavior, and no residue class exists for a zero
  // modulus, so nothing can match.
  if (modulus == 0)
    return hashes;

  for (const auto& [serial, enrollment_state] : initial_enrollment_states_) {
    uint64_t hash = 0UL;
    uint8_t hash_bytes[sizeof(hash)];
    // Only sizeof(hash) bytes of digest output are requested.
    crypto::SHA256HashString(serial, hash_bytes, sizeof(hash));
    base::ReadBigEndian(hash_bytes, &hash);
    if (hash % modulus == remainder) {
      hashes.emplace_back(reinterpret_cast<const char*>(hash_bytes),
                          sizeof(hash));
    }
  }
  return hashes;
}
} // namespace policy