| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/unexportable_keys/unexportable_key_service_impl.h" |
| |
| #include <algorithm> |
| #include <variant> |
| |
| #include "base/containers/to_vector.h" |
| #include "base/functional/bind.h" |
| #include "base/functional/callback.h" |
| #include "base/logging.h" |
| #include "base/memory/weak_ptr.h" |
| #include "base/types/expected.h" |
| #include "components/unexportable_keys/service_error.h" |
| #include "components/unexportable_keys/unexportable_key_id.h" |
| #include "components/unexportable_keys/unexportable_key_task_manager.h" |
| #include "crypto/unexportable_key.h" |
| |
| namespace unexportable_keys { |
| |
| // Class holding either an `UnexportableKeyId` or a list of callbacks waiting |
| // for the key creation. |
| class MaybePendingUnexportableKeyId { |
| public: |
| using CallbackType = |
| base::OnceCallback<void(ServiceErrorOr<UnexportableKeyId>)>; |
| using PendingCallbacks = std::vector<CallbackType>; |
| using PendingCallbacksOrKeyId = |
| std::variant<PendingCallbacks, UnexportableKeyId>; |
| |
| // Constructs an instance holding a list of callbacks. |
| MaybePendingUnexportableKeyId() = default; |
| |
| // Constructs an instance holding `key_id`. |
| explicit MaybePendingUnexportableKeyId(UnexportableKeyId key_id) |
| : pending_callbacks_or_key_id_(key_id) {} |
| |
| // Returns true if a key has been assigned to this instance. Otherwise, |
| // returns false which means that this instance holds a list of callbacks. |
| bool HasKeyId() const { |
| return std::holds_alternative<UnexportableKeyId>( |
| pending_callbacks_or_key_id_); |
| } |
| |
| // This method should be called only if `HasKeyId()` is true. |
| UnexportableKeyId GetKeyId() const { |
| CHECK(HasKeyId()); |
| return std::get<UnexportableKeyId>(pending_callbacks_or_key_id_); |
| } |
| |
| // These methods should be called only if `HasKeyId()` is false. |
| |
| // Adds `callback` to the list of callbacks and returns size of the list. |
| size_t AddCallback(CallbackType callback) { |
| CHECK(!HasKeyId()); |
| GetCallbacks().push_back(std::move(callback)); |
| return GetCallbacks().size(); |
| } |
| |
| void SetKeyIdAndRunCallbacks(UnexportableKeyId key_id) { |
| CHECK(!HasKeyId()); |
| PendingCallbacksOrKeyId pending_callbacks = |
| std::exchange(pending_callbacks_or_key_id_, key_id); |
| for (auto& callback : std::get<PendingCallbacks>(pending_callbacks)) { |
| std::move(callback).Run(key_id); |
| } |
| } |
| |
| void RunCallbacksWithFailure(ServiceError error) { |
| CHECK(!HasKeyId()); |
| for (auto& callback : std::exchange(GetCallbacks(), PendingCallbacks())) { |
| std::move(callback).Run(base::unexpected(error)); |
| } |
| } |
| |
| private: |
| PendingCallbacks& GetCallbacks() { |
| CHECK(!HasKeyId()); |
| return std::get<PendingCallbacks>(pending_callbacks_or_key_id_); |
| } |
| |
| // Holds the value of its first alternative type by default. |
| PendingCallbacksOrKeyId pending_callbacks_or_key_id_; |
| }; |
| |
| UnexportableKeyServiceImpl::UnexportableKeyServiceImpl( |
| UnexportableKeyTaskManager& task_manager, |
| crypto::UnexportableKeyProvider::Config config) |
| : task_manager_(task_manager), config_(config) {} |
| |
| UnexportableKeyServiceImpl::~UnexportableKeyServiceImpl() = default; |
| |
| // static |
| bool UnexportableKeyServiceImpl::IsUnexportableKeyProviderSupported( |
| crypto::UnexportableKeyProvider::Config config) { |
| return UnexportableKeyTaskManager::GetUnexportableKeyProvider( |
| std::move(config)) != nullptr; |
| } |
| |
| void UnexportableKeyServiceImpl::GenerateSigningKeySlowlyAsync( |
| base::span<const crypto::SignatureVerifier::SignatureAlgorithm> |
| acceptable_algorithms, |
| BackgroundTaskPriority priority, |
| base::OnceCallback<void(ServiceErrorOr<UnexportableKeyId>)> callback) { |
| task_manager_->GenerateSigningKeySlowlyAsync( |
| config_, acceptable_algorithms, priority, |
| base::BindOnce(&UnexportableKeyServiceImpl::OnKeyGenerated, |
| generate_key_weak_ptr_factory_.GetWeakPtr(), |
| std::move(callback))); |
| } |
| |
| void UnexportableKeyServiceImpl::FromWrappedSigningKeySlowlyAsync( |
| base::span<const uint8_t> wrapped_key, |
| BackgroundTaskPriority priority, |
| base::OnceCallback<void(ServiceErrorOr<UnexportableKeyId>)> callback) { |
| auto& [wrapped_key_vec, maybe_pending_key_id] = |
| *key_id_by_wrapped_key_.lazy_emplace(wrapped_key, [&](const auto& ctor) { |
| ctor(base::ToVector(wrapped_key), MaybePendingUnexportableKeyId()); |
| }); |
| |
| if (maybe_pending_key_id.HasKeyId()) { |
| std::move(callback).Run(maybe_pending_key_id.GetKeyId()); |
| return; |
| } |
| |
| size_t n_callbacks = maybe_pending_key_id.AddCallback(std::move(callback)); |
| if (n_callbacks == 1) { |
| // `callback` is the first one waiting for the wrapped key. Schedule the |
| // task to create a key from the wrapped key. |
| task_manager_->FromWrappedSigningKeySlowlyAsync( |
| config_, wrapped_key, priority, |
| base::BindOnce(&UnexportableKeyServiceImpl::OnKeyCreatedFromWrappedKey, |
| from_wrapped_key_weak_ptr_factory_.GetWeakPtr(), |
| wrapped_key_vec)); |
| } |
| } |
| |
| void UnexportableKeyServiceImpl:: |
| GetAllSigningKeysForGarbageCollectionSlowlyAsync( |
| BackgroundTaskPriority priority, |
| base::OnceCallback<void(ServiceErrorOr<std::vector<UnexportableKeyId>>)> |
| callback) { |
| // TODO: crbug.com/455538141 - Implement key retrieval in the task manager. |
| std::move(callback).Run(std::vector<UnexportableKeyId>()); |
| } |
| |
| void UnexportableKeyServiceImpl::SignSlowlyAsync( |
| UnexportableKeyId key_id, |
| base::span<const uint8_t> data, |
| BackgroundTaskPriority priority, |
| base::OnceCallback<void(ServiceErrorOr<std::vector<uint8_t>>)> callback) { |
| auto it = key_by_key_id_.find(key_id); |
| if (it == key_by_key_id_.end()) { |
| std::move(callback).Run(base::unexpected(ServiceError::kKeyNotFound)); |
| return; |
| } |
| task_manager_->SignSlowlyAsync(it->second, data, priority, |
| std::move(callback)); |
| } |
| |
| void UnexportableKeyServiceImpl::DeleteKeySlowlyAsync( |
| UnexportableKeyId key_id, |
| BackgroundTaskPriority priority, |
| base::OnceCallback<void(ServiceErrorOr<void>)> callback) { |
| auto key_id_it = key_by_key_id_.find(key_id); |
| if (key_id_it == key_by_key_id_.end()) { |
| std::move(callback).Run(base::unexpected(ServiceError::kKeyNotFound)); |
| return; |
| } |
| |
| const std::vector<uint8_t> wrapped_key = |
| key_id_it->second->key().GetWrappedKey(); |
| auto wrapped_key_it = key_id_by_wrapped_key_.find(wrapped_key); |
| CHECK(wrapped_key_it != key_id_by_wrapped_key_.end()); |
| CHECK(wrapped_key_it->second.HasKeyId()); |
| CHECK_EQ(wrapped_key_it->second.GetKeyId(), key_id); |
| |
| key_by_key_id_.erase(key_id_it); |
| key_id_by_wrapped_key_.erase(wrapped_key_it); |
| |
| // TODO: crbug.com/455538141 - Implement deletion in the task manager. |
| std::move(callback).Run(base::ok()); |
| } |
| |
| void UnexportableKeyServiceImpl::DeleteAllKeysSlowlyAsync( |
| BackgroundTaskPriority priority, |
| base::OnceCallback<void(ServiceErrorOr<void>)> callback) { |
| key_by_key_id_.clear(); |
| |
| // Clear the in-memory cache of pending key IDs by moving it to a local |
| // variable and run pending callbacks with a failure. |
| for (auto& [_, maybe_pending_key_id] : |
| std::exchange(key_id_by_wrapped_key_, {})) { |
| if (!maybe_pending_key_id.HasKeyId()) { |
| maybe_pending_key_id.RunCallbacksWithFailure(ServiceError::kKeyNotFound); |
| } |
| } |
| |
| // Invalidate weak pointers to cancel pending from wrapped key requests. |
| from_wrapped_key_weak_ptr_factory_.InvalidateWeakPtrs(); |
| |
| // TODO: crbug.com/455538141 - Implement deletion in the task manager. |
| std::move(callback).Run(base::ok()); |
| } |
| |
| ServiceErrorOr<std::vector<uint8_t>> |
| UnexportableKeyServiceImpl::GetSubjectPublicKeyInfo( |
| UnexportableKeyId key_id) const { |
| auto it = key_by_key_id_.find(key_id); |
| if (it == key_by_key_id_.end()) { |
| return base::unexpected(ServiceError::kKeyNotFound); |
| } |
| return it->second->key().GetSubjectPublicKeyInfo(); |
| } |
| |
| ServiceErrorOr<std::vector<uint8_t>> UnexportableKeyServiceImpl::GetWrappedKey( |
| UnexportableKeyId key_id) const { |
| auto it = key_by_key_id_.find(key_id); |
| if (it == key_by_key_id_.end()) { |
| return base::unexpected(ServiceError::kKeyNotFound); |
| } |
| return it->second->key().GetWrappedKey(); |
| } |
| |
| ServiceErrorOr<crypto::SignatureVerifier::SignatureAlgorithm> |
| UnexportableKeyServiceImpl::GetAlgorithm(UnexportableKeyId key_id) const { |
| auto it = key_by_key_id_.find(key_id); |
| if (it == key_by_key_id_.end()) { |
| return base::unexpected(ServiceError::kKeyNotFound); |
| } |
| return it->second->key().Algorithm(); |
| } |
| |
| void UnexportableKeyServiceImpl::OnKeyGenerated( |
| base::OnceCallback<void(ServiceErrorOr<UnexportableKeyId>)> client_callback, |
| ServiceErrorOr<scoped_refptr<RefCountedUnexportableSigningKey>> |
| key_or_error) { |
| std::move(client_callback).Run([&]() -> ServiceErrorOr<UnexportableKeyId> { |
| if (!key_or_error.has_value()) { |
| return base::unexpected(key_or_error.error()); |
| } |
| scoped_refptr<RefCountedUnexportableSigningKey>& key = key_or_error.value(); |
| // `key` must be non-null if `key_or_error` holds a value. |
| CHECK(key); |
| UnexportableKeyId key_id = key->id(); |
| if (!key_id_by_wrapped_key_.try_emplace(key->key().GetWrappedKey(), key_id) |
| .second) { |
| // Drop a newly generated key in the case of a key collision. This should |
| // be extremely rare. |
| DVLOG(1) << "Collision between an existing and a newly generated key " |
| "detected."; |
| return base::unexpected(ServiceError::kKeyCollision); |
| } |
| // A newly generated key ID must be unique. |
| CHECK(key_by_key_id_.try_emplace(key_id, std::move(key)).second); |
| return key_id; |
| }()); |
| } |
| |
| void UnexportableKeyServiceImpl::OnKeyCreatedFromWrappedKey( |
| std::vector<uint8_t> wrapped_key, |
| ServiceErrorOr<scoped_refptr<RefCountedUnexportableSigningKey>> |
| key_or_error) { |
| auto it = key_id_by_wrapped_key_.find(wrapped_key); |
| CHECK(it != key_id_by_wrapped_key_.end()); |
| |
| auto& [_, pending_callbacks] = *it; |
| CHECK(!pending_callbacks.HasKeyId()); |
| |
| if (!key_or_error.has_value()) { |
| auto node = key_id_by_wrapped_key_.extract(it); |
| node.mapped().RunCallbacksWithFailure(key_or_error.error()); |
| return; |
| } |
| scoped_refptr<RefCountedUnexportableSigningKey>& key = key_or_error.value(); |
| // `key` must be non-null if `key_or_error` holds a value. |
| CHECK(key); |
| DCHECK(wrapped_key == key->key().GetWrappedKey()); |
| |
| UnexportableKeyId key_id = key->id(); |
| // A newly created key ID must be unique. |
| CHECK(key_by_key_id_.try_emplace(key_id, std::move(key)).second); |
| pending_callbacks.SetKeyIdAndRunCallbacks(key_id); |
| } |
| |
| } // namespace unexportable_keys |