| // Copyright 2014 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/ssl/channel_id_service.h" |
| |
| #include <algorithm> |
| #include <limits> |
| #include <memory> |
| #include <utility> |
| |
| #include "base/atomic_sequence_num.h" |
| #include "base/bind.h" |
| #include "base/bind_helpers.h" |
| #include "base/callback_helpers.h" |
| #include "base/compiler_specific.h" |
| #include "base/location.h" |
| #include "base/logging.h" |
| #include "base/macros.h" |
| #include "base/memory/ref_counted.h" |
| #include "base/metrics/histogram_macros.h" |
| #include "base/rand_util.h" |
| #include "base/single_thread_task_runner.h" |
| #include "base/stl_util.h" |
| #include "base/task_runner.h" |
| #include "base/threading/thread_task_runner_handle.h" |
| #include "crypto/ec_private_key.h" |
| #include "net/base/net_errors.h" |
| #include "net/base/registry_controlled_domains/registry_controlled_domain.h" |
| #include "net/cert/x509_certificate.h" |
| #include "net/cert/x509_util.h" |
| #include "url/gurl.h" |
| |
| namespace net { |
| |
| namespace { |
| |
| // Shared sequence counter: every ChannelIDService constructor takes the next |
| // value (GetNext()) as its |id_|. |
| base::StaticAtomicSequenceNumber g_next_id; |
| |
| // Used by the GetDomainBoundCertResult histogram to record the final |
| // outcome of each GetChannelID or GetOrCreateChannelID call. |
| // Do not re-use values: each enumerator is the sample value written by |
| // RecordGetChannelIDResult(), so renumbering or recycling a value would |
| // change the meaning of previously recorded samples. |
| enum GetChannelIDResult { |
|   // Synchronously found and returned an existing domain bound cert. |
|   SYNC_SUCCESS = 0, |
|   // Retrieved or generated and returned a domain bound cert asynchronously. |
|   ASYNC_SUCCESS = 1, |
|   // Retrieval/generation request was cancelled before the cert generation |
|   // completed. |
|   ASYNC_CANCELLED = 2, |
|   // Cert generation failed. |
|   ASYNC_FAILURE_KEYGEN = 3, |
|   // Result code 4 was removed (ASYNC_FAILURE_CREATE_CERT); leave the gap. |
|   // Async completion with ERR_PRIVATE_KEY_EXPORT_FAILED. |
|   ASYNC_FAILURE_EXPORT_KEY = 5, |
|   // Async completion with an error that has no more specific value here. |
|   ASYNC_FAILURE_UNKNOWN = 6, |
|   // GetChannelID or GetOrCreateChannelID was called with |
|   // invalid arguments. |
|   INVALID_ARGUMENT = 7, |
|   // We don't support any of the cert types the server requested. |
|   UNSUPPORTED_TYPE = 8, |
|   // Server asked for a different type of certs while we were generating one. |
|   TYPE_MISMATCH = 9, |
|   // Couldn't start a worker to generate a cert. |
|   WORKER_FAILURE = 10, |
|   // Boundary value passed to UMA_HISTOGRAM_ENUMERATION; must remain last. |
|   GET_CHANNEL_ID_RESULT_MAX |
| }; |
| |
| // Adds |result| as one sample to the |
| // "DomainBoundCerts.GetDomainBoundCertResult" enumeration histogram. |
| void RecordGetChannelIDResult(GetChannelIDResult result) { |
|   UMA_HISTOGRAM_ENUMERATION("DomainBoundCerts.GetDomainBoundCertResult", result, |
|                             GET_CHANNEL_ID_RESULT_MAX); |
| } |
| |
| // Records the end-to-end latency of a successful channel ID request; both |
| // the synchronous and the asynchronous success paths report here. |
| void RecordGetChannelIDTime(base::TimeDelta request_time) { |
|   // 50 custom buckets spanning 1 ms .. 5 minutes. |
|   const base::TimeDelta kShortest = base::TimeDelta::FromMilliseconds(1); |
|   const base::TimeDelta kLongest = base::TimeDelta::FromMinutes(5); |
|   const int kBucketCount = 50; |
|   UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GetCertTime", request_time, |
|                              kShortest, kLongest, kBucketCount); |
| } |
| |
| // Creates a new EC key pair for |server_identifier| and wraps it in a |
| // ChannelIDStore::ChannelID stamped with the current wall-clock time. |
| // |
| // On success, returns the ChannelID object and sets |*error| to OK. |
| // Otherwise, returns null and sets |*error| to ERR_KEY_GENERATION_FAILED. |
| // |
| // Called from ChannelIDServiceWorker::Run(), i.e. on the worker task runner |
| // rather than the service's thread; it only uses its arguments and the |
| // clocks. Generation latency is recorded in |
| // "DomainBoundCerts.GenerateCertTime". |
| std::unique_ptr<ChannelIDStore::ChannelID> GenerateChannelID( |
|     const std::string& server_identifier, |
|     int* error) { |
|   const base::TimeTicks start = base::TimeTicks::Now(); |
|   const base::Time creation_time = base::Time::Now(); |
|   std::unique_ptr<crypto::ECPrivateKey> key(crypto::ECPrivateKey::Create()); |
| |
|   if (!key) { |
|     DLOG(ERROR) << "Unable to create channel ID key pair"; |
|     *error = ERR_KEY_GENERATION_FAILED; |
|     return nullptr; |
|   } |
| |
|   std::unique_ptr<ChannelIDStore::ChannelID> result( |
|       new ChannelIDStore::ChannelID(server_identifier, creation_time, |
|                                     std::move(key))); |
|   UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GenerateCertTime", |
|                              base::TimeTicks::Now() - start, |
|                              base::TimeDelta::FromMilliseconds(1), |
|                              base::TimeDelta::FromMinutes(5), |
|                              50); |
|   *error = OK; |
|   return result; |
| } |
| |
| } // namespace |
| |
| // ChannelIDServiceWorker performs the blocking key generation for one |
| // domain on a worker task runner and reports back to the thread that |
| // created it. Once Start() has been called the worker owns itself: it is |
| // deleted after Run() finishes, or right away if the task cannot be posted. |
| class ChannelIDServiceWorker { |
|  public: |
|   // Run on the origin thread with (server identifier, net error, |
|   // generated channel ID or null). |
|   using WorkerDoneCallback = base::Callback< |
|       void(const std::string&, int, std::unique_ptr<ChannelIDStore::ChannelID>)>; |
| |
|   // Must be constructed on the thread that should receive |callback|; that |
|   // thread's task runner is captured here. |
|   ChannelIDServiceWorker(const std::string& server_identifier, |
|                          const WorkerDoneCallback& callback) |
|       : server_identifier_(server_identifier), |
|         origin_task_runner_(base::ThreadTaskRunnerHandle::Get()), |
|         callback_(callback) {} |
| |
|   // Posts Run() to |task_runner| and reports whether posting succeeded. The |
|   // bound closure owns |this| (base::Owned), so the worker is destroyed |
|   // with that closure -- including when PostTask() fails, e.g. because |
|   // |task_runner| is shutting down. |
|   bool Start(const scoped_refptr<base::TaskRunner>& task_runner) { |
|     DCHECK(origin_task_runner_->RunsTasksOnCurrentThread()); |
|     return task_runner->PostTask( |
|         FROM_HERE, base::Bind(&ChannelIDServiceWorker::Run, base::Owned(this))); |
|   } |
| |
|  private: |
|   // Generates the key on the worker task runner, then ships the outcome |
|   // back to the origin thread through |callback_|. |
|   void Run() { |
|     int result_code = ERR_FAILED; |
|     std::unique_ptr<ChannelIDStore::ChannelID> generated = |
|         GenerateChannelID(server_identifier_, &result_code); |
|     origin_task_runner_->PostTask( |
|         FROM_HERE, |
|         base::Bind(callback_, server_identifier_, result_code, |
|                    base::Passed(&generated))); |
|   } |
| |
|   const std::string server_identifier_; |
|   scoped_refptr<base::SequencedTaskRunner> origin_task_runner_; |
|   WorkerDoneCallback callback_; |
| |
|   DISALLOW_COPY_AND_ASSIGN(ChannelIDServiceWorker); |
| }; |
| |
| // A ChannelIDServiceJob collects every Request waiting on the same domain |
| // while a store lookup and/or key generation for that domain is in flight. |
| // It lives only on the ChannelIDService's origin task runner's thread and is |
| // owned by the service's |inflight_| map. |
| class ChannelIDServiceJob { |
|  public: |
|   // |create_if_missing| records whether a missing channel ID should be |
|   // generated rather than reported back as ERR_FILE_NOT_FOUND. explicit so |
|   // a bare bool cannot silently convert into a job. |
|   explicit ChannelIDServiceJob(bool create_if_missing) |
|       : create_if_missing_(create_if_missing) {} |
| |
|   // Every request must have been posted or cancelled before destruction. |
|   ~ChannelIDServiceJob() { DCHECK(requests_.empty()); } |
| |
|   // Attaches |request|. Passing |create_if_missing| == true upgrades the |
|   // job; the flag is sticky and never cleared once set. |
|   void AddRequest(ChannelIDService::Request* request, |
|                   bool create_if_missing = false) { |
|     create_if_missing_ |= create_if_missing; |
|     requests_.push_back(request); |
|   } |
| |
|   // Delivers the final |error| (and |key|, if any) to every attached request. |
|   void HandleResult(int error, std::unique_ptr<crypto::ECPrivateKey> key) { |
|     PostAll(error, std::move(key)); |
|   } |
| |
|   bool CreateIfMissing() const { return create_if_missing_; } |
| |
|   // Detaches |req| so it will not be posted; no-op if it is not attached. |
|   void CancelRequest(ChannelIDService::Request* req) { |
|     auto it = std::find(requests_.begin(), requests_.end(), req); |
|     if (it != requests_.end()) |
|       requests_.erase(it); |
|   } |
| |
|  private: |
|   // Hands each waiting request its own copy of |key| (null on failure). |
|   // |requests_| is swapped out first so it is empty while callbacks run. |
|   void PostAll(int error, std::unique_ptr<crypto::ECPrivateKey> key) { |
|     std::vector<ChannelIDService::Request*> requests; |
|     requests_.swap(requests); |
| |
|     for (ChannelIDService::Request* request : requests) { |
|       std::unique_ptr<crypto::ECPrivateKey> key_copy; |
|       if (key) |
|         key_copy.reset(key->Copy()); |
|       request->Post(error, std::move(key_copy)); |
|     } |
|   } |
| |
|   std::vector<ChannelIDService::Request*> requests_; |
|   bool create_if_missing_; |
| |
|   DISALLOW_COPY_AND_ASSIGN(ChannelIDServiceJob); |
| }; |
| |
| // static |
| // Empty password constant exposed as ChannelIDService::kEPKIPassword. |
| // NOTE(review): presumably the password for encrypted private-key (EPKI) |
| // import/export; it is not referenced in the code shown here -- confirm at |
| // the call sites. |
| const char ChannelIDService::kEPKIPassword[] = ""; |
| |
| // A Request starts idle (|service_| null); RequestStarted() makes it active |
| // until it is either posted or cancelled. |
| ChannelIDService::Request::Request() : service_(nullptr) { |
|   // |key_| and |job_| are only read while the request is active, but null |
|   // them anyway so an idle Request never carries indeterminate pointers. |
|   key_ = nullptr; |
|   job_ = nullptr; |
| } |
| |
| // Destroying an active request cancels it, detaching it from its job so the |
| // job will not post to it afterwards. |
| ChannelIDService::Request::~Request() { |
|   Cancel(); |
| } |
| |
| // Cancels an active request: records ASYNC_CANCELLED, drops the callback (it |
| // will never run) and detaches from the job. No-op on an idle request. |
| void ChannelIDService::Request::Cancel() { |
|   if (!service_) |
|     return; |
| |
|   RecordGetChannelIDResult(ASYNC_CANCELLED); |
|   callback_.Reset(); |
|   job_->CancelRequest(this); |
|   service_ = nullptr; |
| } |
| |
| // Marks this Request active. It remembers the owning service, the start |
| // time, the completion callback, the caller's output slot for the key and |
| // the job it is attached to. The Request must currently be idle. |
| void ChannelIDService::Request::RequestStarted( |
|     ChannelIDService* service, |
|     base::TimeTicks request_start, |
|     const CompletionCallback& callback, |
|     std::unique_ptr<crypto::ECPrivateKey>* key, |
|     ChannelIDServiceJob* job) { |
|   DCHECK(!service_); |
|   job_ = job; |
|   key_ = key; |
|   callback_ = callback; |
|   request_start_ = request_start; |
|   service_ = service; |
| } |
| |
| void ChannelIDService::Request::Post( |
| int error, |
| std::unique_ptr<crypto::ECPrivateKey> key) { |
| switch (error) { |
| case OK: { |
| base::TimeDelta request_time = base::TimeTicks::Now() - request_start_; |
| UMA_HISTOGRAM_CUSTOM_TIMES("DomainBoundCerts.GetCertTimeAsync", |
| request_time, |
| base::TimeDelta::FromMilliseconds(1), |
| base::TimeDelta::FromMinutes(5), 50); |
| RecordGetChannelIDTime(request_time); |
| RecordGetChannelIDResult(ASYNC_SUCCESS); |
| break; |
| } |
| case ERR_KEY_GENERATION_FAILED: |
| RecordGetChannelIDResult(ASYNC_FAILURE_KEYGEN); |
| break; |
| case ERR_PRIVATE_KEY_EXPORT_FAILED: |
| RecordGetChannelIDResult(ASYNC_FAILURE_EXPORT_KEY); |
| break; |
| case ERR_INSUFFICIENT_RESOURCES: |
| RecordGetChannelIDResult(WORKER_FAILURE); |
| break; |
| default: |
| RecordGetChannelIDResult(ASYNC_FAILURE_UNKNOWN); |
| break; |
| } |
| service_ = NULL; |
| DCHECK(!callback_.is_null()); |
| if (key) |
| *key_ = std::move(key); |
| // Running the callback might delete |this| (e.g. the callback cleans up |
| // resources created for the request), so we can't touch any of our |
| // members afterwards. Reset callback_ first. |
| base::ResetAndReturn(&callback_).Run(error); |
| } |
| |
| // Creates a service that looks keys up in, and saves generated channel IDs |
| // to, |channel_id_store|, and that runs blocking key generation on |
| // |task_runner|. Each instance takes a fresh |id_| from g_next_id; all |
| // statistics counters start at zero. |
| // NOTE(review): |channel_id_store_| is accessed via .get(), so it is |
| // presumably a smart pointer that takes ownership -- confirm in the header. |
| ChannelIDService::ChannelIDService( |
|     ChannelIDStore* channel_id_store, |
|     const scoped_refptr<base::TaskRunner>& task_runner) |
|     : channel_id_store_(channel_id_store), |
|       task_runner_(task_runner), |
|       id_(g_next_id.GetNext()), |
|       requests_(0), |
|       key_store_hits_(0), |
|       inflight_joins_(0), |
|       workers_created_(0), |
|       weak_ptr_factory_(this) {} |
| |
| // Frees every job still waiting in |inflight_| and empties the map. |
| ChannelIDService::~ChannelIDService() { |
|   for (auto& domain_and_job : inflight_) |
|     delete domain_and_job.second; |
|   inflight_.clear(); |
| } |
| |
| // static |
| // Maps |host| to its registrable domain (private registries included). If |
| // the registry lookup yields an empty string, |host| is returned unchanged. |
| std::string ChannelIDService::GetDomainForHost(const std::string& host) { |
|   std::string registrable = registry_controlled_domains::GetDomainAndRegistry( |
|       host, registry_controlled_domains::INCLUDE_PRIVATE_REGISTRIES); |
|   return registrable.empty() ? host : registrable; |
| } |
| |
| // Fetches the channel ID key for |host|'s registrable domain. If the store |
| // has none, a new key is generated on |task_runner_|. |
| // |
| // Returns one of: |
| //   * OK on a synchronous store hit (the store was handed |key| to fill); |
| //   * ERR_IO_PENDING when |callback| will be run later through |out_req|; |
| //   * a net error. |
| int ChannelIDService::GetOrCreateChannelID( |
|     const std::string& host, |
|     std::unique_ptr<crypto::ECPrivateKey>* key, |
|     const CompletionCallback& callback, |
|     Request* out_req) { |
|   DVLOG(1) << __FUNCTION__ << " " << host; |
|   DCHECK(CalledOnValidThread()); |
|   const base::TimeTicks request_start = base::TimeTicks::Now(); |
| |
|   const bool arguments_valid = !callback.is_null() && key && !host.empty(); |
|   if (!arguments_valid) { |
|     RecordGetChannelIDResult(INVALID_ARGUMENT); |
|     return ERR_INVALID_ARGUMENT; |
|   } |
| |
|   const std::string domain = GetDomainForHost(host); |
|   if (domain.empty()) { |
|     RecordGetChannelIDResult(INVALID_ARGUMENT); |
|     return ERR_INVALID_ARGUMENT; |
|   } |
| |
|   requests_++; |
| |
|   // This caller wants a key even if none exists yet. |
|   const bool create_if_missing = true; |
| |
|   // Piggy-back on an in-flight request for the same domain, if any. |
|   if (JoinToInFlightRequest(request_start, domain, key, create_if_missing, |
|                             callback, out_req)) { |
|     return ERR_IO_PENDING; |
|   } |
| |
|   const int lookup_result = LookupChannelID(request_start, domain, key, |
|                                             create_if_missing, callback, |
|                                             out_req); |
|   if (lookup_result != ERR_FILE_NOT_FOUND) |
|     return lookup_result; |
| |
|   // The synchronous lookup found nothing: generate a fresh key off-thread. |
|   workers_created_++; |
|   ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( |
|       domain, base::Bind(&ChannelIDService::GeneratedChannelID, |
|                          weak_ptr_factory_.GetWeakPtr())); |
|   // Once Start() is called the worker deletes itself, even on failure. |
|   if (!worker->Start(task_runner_)) { |
|     // TODO(rkn): Log to the NetLog. |
|     LOG(ERROR) << "ChannelIDServiceWorker couldn't be started."; |
|     RecordGetChannelIDResult(WORKER_FAILURE); |
|     return ERR_INSUFFICIENT_RESOURCES; |
|   } |
| |
|   // Track the pending generation so later callers can join it. |
|   ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing); |
|   inflight_[domain] = job; |
|   job->AddRequest(out_req); |
|   out_req->RequestStarted(this, request_start, callback, key, job); |
|   return ERR_IO_PENDING; |
| } |
| |
| // Like GetOrCreateChannelID(), but never starts key generation itself. A |
| // missing channel ID is reported to the caller with the store's result |
| // (ERR_FILE_NOT_FOUND) instead. Returns OK, ERR_IO_PENDING (|callback| runs |
| // later through |out_req|) or a net error. |
| int ChannelIDService::GetChannelID(const std::string& host, |
|                                    std::unique_ptr<crypto::ECPrivateKey>* key, |
|                                    const CompletionCallback& callback, |
|                                    Request* out_req) { |
|   DVLOG(1) << __FUNCTION__ << " " << host; |
|   DCHECK(CalledOnValidThread()); |
|   base::TimeTicks request_start = base::TimeTicks::Now(); |
| |
|   if (callback.is_null() || !key || host.empty()) { |
|     RecordGetChannelIDResult(INVALID_ARGUMENT); |
|     return ERR_INVALID_ARGUMENT; |
|   } |
| |
|   std::string domain = GetDomainForHost(host); |
|   if (domain.empty()) { |
|     RecordGetChannelIDResult(INVALID_ARGUMENT); |
|     return ERR_INVALID_ARGUMENT; |
|   } |
| |
|   requests_++; |
| |
|   // See if a request for the same domain is currently in flight. Joining |
|   // never clears a job's create-if-missing flag, so this request can still |
|   // receive a freshly generated key if another caller asked for one. |
|   bool create_if_missing = false; |
|   if (JoinToInFlightRequest(request_start, domain, key, create_if_missing, |
|                             callback, out_req)) { |
|     return ERR_IO_PENDING; |
|   } |
| |
|   return LookupChannelID(request_start, domain, key, create_if_missing, |
|                          callback, out_req); |
| } |
| |
| // Completion callback for the asynchronous ChannelIDStore::GetChannelID() |
| // lookup started by LookupChannelID(). |
| // |
| // A hit, or any failure other than a miss, goes straight to the waiting job. |
| // On a miss, a key is generated only if some attached request asked for |
| // creation. |
| void ChannelIDService::GotChannelID(int err, |
|                                     const std::string& server_identifier, |
|                                     std::unique_ptr<crypto::ECPrivateKey> key) { |
|   DCHECK(CalledOnValidThread()); |
| |
|   auto job_it = inflight_.find(server_identifier); |
|   if (job_it == inflight_.end()) { |
|     NOTREACHED(); |
|     return; |
|   } |
| |
|   if (err == OK) { |
|     // Async DB lookup found a valid channel ID; Request::Post() records the |
|     // histograms. |
|     key_store_hits_++; |
|     HandleResult(OK, server_identifier, std::move(key)); |
|     return; |
|   } |
| |
|   const bool should_generate = |
|       err == ERR_FILE_NOT_FOUND && job_it->second->CreateIfMissing(); |
|   if (!should_generate) { |
|     HandleResult(err, server_identifier, std::move(key)); |
|     return; |
|   } |
| |
|   // The job stays in |inflight_|; GeneratedChannelID() will complete it. |
|   workers_created_++; |
|   ChannelIDServiceWorker* worker = new ChannelIDServiceWorker( |
|       server_identifier, |
|       base::Bind(&ChannelIDService::GeneratedChannelID, |
|                  weak_ptr_factory_.GetWeakPtr())); |
|   if (!worker->Start(task_runner_)) { |
|     // TODO(rkn): Log to the NetLog. |
|     LOG(ERROR) << "ChannelIDServiceWorker couldn't be started."; |
|     HandleResult(ERR_INSUFFICIENT_RESOURCES, server_identifier, nullptr); |
|   } |
| } |
| |
| // Returns the backing store as a raw, non-owning pointer. |
| ChannelIDStore* ChannelIDService::GetChannelIDStore() { |
|   return channel_id_store_.get(); |
| } |
| |
| // Worker completion callback, run on the service's thread. |
| // On success the new channel ID is saved to the store and every waiting |
| // request receives a copy of its key. On failure they all receive |error|. |
| void ChannelIDService::GeneratedChannelID( |
|     const std::string& server_identifier, |
|     int error, |
|     std::unique_ptr<ChannelIDStore::ChannelID> channel_id) { |
|   DCHECK(CalledOnValidThread()); |
| |
|   std::unique_ptr<crypto::ECPrivateKey> key_for_requests; |
|   if (error == OK) { |
|     // Copy the key before |channel_id| is moved into the store. |
|     key_for_requests.reset(channel_id->key()->Copy()); |
|     channel_id_store_->SetChannelID(std::move(channel_id)); |
|   } |
|   HandleResult(error, server_identifier, std::move(key_for_requests)); |
| } |
| |
| // Completes the job for |server_identifier|: removes it from |inflight_|, |
| // posts |error| / |key| to every attached request, then destroys the job. |
| void ChannelIDService::HandleResult(int error, |
|                                     const std::string& server_identifier, |
|                                     std::unique_ptr<crypto::ECPrivateKey> key) { |
|   DCHECK(CalledOnValidThread()); |
| |
|   auto it = inflight_.find(server_identifier); |
|   if (it == inflight_.end()) { |
|     NOTREACHED(); |
|     return; |
|   } |
|   // Take ownership and unlink the job before any callback runs. A request |
|   // issued from a callback for the same domain then starts a new job instead |
|   // of joining this finished one. The job is freed when |job| goes out of |
|   // scope. |
|   std::unique_ptr<ChannelIDServiceJob> job(it->second); |
|   inflight_.erase(it); |
| |
|   job->HandleResult(error, std::move(key)); |
| } |
| |
| // If a lookup/generation for |domain| is already in flight, attaches |
| // |out_req| to that job and returns true. Otherwise returns false and |
| // leaves |out_req| untouched. |
| bool ChannelIDService::JoinToInFlightRequest( |
|     const base::TimeTicks& request_start, |
|     const std::string& domain, |
|     std::unique_ptr<crypto::ECPrivateKey>* key, |
|     bool create_if_missing, |
|     const CompletionCallback& callback, |
|     Request* out_req) { |
|   auto it = inflight_.find(domain); |
|   if (it == inflight_.end()) |
|     return false; |
| |
|   // A request for the same domain is in flight already, so attach our |
|   // callback to it. If this caller wants a channel ID created when one is |
|   // missing, the job is upgraded to do so (AddRequest ORs the flag in). |
|   ChannelIDServiceJob* job = it->second; |
|   inflight_joins_++; |
| |
|   job->AddRequest(out_req, create_if_missing); |
|   out_req->RequestStarted(this, request_start, callback, key, job); |
|   return true; |
| } |
| |
| // Asks the store for |domain|'s key. The outcome depends on the store: |
| //   * Synchronous hit: records the sync-success metrics and returns OK. |
| //   * Answer comes later (via GotChannelID()): registers a new job for |
| //     |domain| with |out_req| attached and returns ERR_IO_PENDING. |
| //   * Any other store result is returned unchanged. |
| int ChannelIDService::LookupChannelID( |
|     const base::TimeTicks& request_start, |
|     const std::string& domain, |
|     std::unique_ptr<crypto::ECPrivateKey>* key, |
|     bool create_if_missing, |
|     const CompletionCallback& callback, |
|     Request* out_req) { |
|   const int store_result = channel_id_store_->GetChannelID( |
|       domain, key, |
|       base::Bind(&ChannelIDService::GotChannelID, |
|                  weak_ptr_factory_.GetWeakPtr())); |
| |
|   switch (store_result) { |
|     case OK: { |
|       DVLOG(1) << "Channel ID store had valid key for " << domain; |
|       key_store_hits_++; |
|       RecordGetChannelIDResult(SYNC_SUCCESS); |
|       const base::TimeDelta elapsed = base::TimeTicks::Now() - request_start; |
|       UMA_HISTOGRAM_TIMES("DomainBoundCerts.GetCertTimeSync", elapsed); |
|       RecordGetChannelIDTime(elapsed); |
|       return OK; |
|     } |
|     case ERR_IO_PENDING: { |
|       ChannelIDServiceJob* job = new ChannelIDServiceJob(create_if_missing); |
|       inflight_[domain] = job; |
|       job->AddRequest(out_req); |
|       out_req->RequestStarted(this, request_start, callback, key, job); |
|       return ERR_IO_PENDING; |
|     } |
|     default: |
|       return store_result; |
|   } |
| } |
| |
| // Returns the number of channel IDs reported by the backing store. |
| int ChannelIDService::channel_id_count() { |
|   return channel_id_store_->GetChannelIDCount(); |
| } |
| |
| } // namespace net |