| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net/shared_dictionary/shared_dictionary_network_transaction.h" |
| |
| #include <optional> |
| #include <string> |
| #include <string_view> |
| |
| #include "base/base64.h" |
| #include "base/feature_list.h" |
| #include "base/functional/bind.h" |
| #include "base/functional/callback_helpers.h" |
| #include "base/metrics/histogram_functions.h" |
| #include "base/metrics/histogram_macros.h" |
| #include "base/notreached.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/string_util.h" |
| #include "base/types/expected.h" |
| #include "net/base/completion_once_callback.h" |
| #include "net/base/features.h" |
| #include "net/base/hash_value.h" |
| #include "net/base/io_buffer.h" |
| #include "net/base/load_flags.h" |
| #include "net/base/net_errors.h" |
| #include "net/base/transport_info.h" |
| #include "net/base/url_util.h" |
| #include "net/cert/x509_certificate.h" |
| #include "net/filter/brotli_source_stream.h" |
| #include "net/filter/filter_source_stream.h" |
| #include "net/filter/source_stream.h" |
| #include "net/filter/source_stream_type.h" |
| #include "net/filter/zstd_source_stream.h" |
| #include "net/http/http_request_info.h" |
| #include "net/http/structured_headers.h" |
| #include "net/shared_dictionary/shared_dictionary_constants.h" |
| #include "net/shared_dictionary/shared_dictionary_header_checker_source_stream.h" |
| #include "net/shared_dictionary/shared_dictionary_isolation_key.h" |
| #include "net/ssl/ssl_private_key.h" |
| |
| namespace net { |
| |
| namespace { |
| |
| // Convert the interface from HttpTransaction to SourceStream. |
| class ProxyingSourceStream : public SourceStream { |
| public: |
| explicit ProxyingSourceStream(HttpTransaction* transaction) |
| : SourceStream(SourceStreamType::kNone), transaction_(transaction) {} |
| |
| ProxyingSourceStream(const ProxyingSourceStream&) = delete; |
| ProxyingSourceStream& operator=(const ProxyingSourceStream&) = delete; |
| |
| ~ProxyingSourceStream() override = default; |
| |
| // SourceStream implementation: |
| int Read(IOBuffer* dest_buffer, |
| int buffer_size, |
| CompletionOnceCallback callback) override { |
| DCHECK(transaction_); |
| return transaction_->Read(dest_buffer, buffer_size, std::move(callback)); |
| } |
| |
| std::string Description() const override { return std::string(); } |
| |
| bool MayHaveMoreBytes() const override { return true; } |
| |
| private: |
| const raw_ptr<HttpTransaction> transaction_; |
| }; |
| |
| void AddAcceptEncoding(HttpRequestHeaders* request_headers, |
| std::string_view encoding_header) { |
| std::optional<std::string> accept_encoding = |
| request_headers->GetHeader(HttpRequestHeaders::kAcceptEncoding); |
| request_headers->SetHeader( |
| HttpRequestHeaders::kAcceptEncoding, |
| accept_encoding ? base::StrCat({*accept_encoding, ", ", encoding_header}) |
| : std::string(encoding_header)); |
| } |
| |
| } // namespace |
| |
| SharedDictionaryNetworkTransaction::PendingReadTask::PendingReadTask( |
| IOBuffer* buf, |
| int buf_len, |
| CompletionOnceCallback callback) |
| : buf(buf), buf_len(buf_len), callback(std::move(callback)) {} |
| |
| SharedDictionaryNetworkTransaction::PendingReadTask::~PendingReadTask() = |
| default; |
| |
| SharedDictionaryNetworkTransaction::SharedDictionaryNetworkTransaction( |
| std::unique_ptr<HttpTransaction> network_transaction, |
| bool enable_shared_zstd) |
| : enable_shared_zstd_(enable_shared_zstd), |
| network_transaction_(std::move(network_transaction)) { |
| network_transaction_->SetConnectedCallback( |
| base::BindRepeating(&SharedDictionaryNetworkTransaction::OnConnected, |
| base::Unretained(this))); |
| } |
| |
| SharedDictionaryNetworkTransaction::~SharedDictionaryNetworkTransaction() = |
| default; |
| |
| int SharedDictionaryNetworkTransaction::Start(const HttpRequestInfo* request, |
| CompletionOnceCallback callback, |
| const NetLogWithSource& net_log) { |
| if (!(request->load_flags & LOAD_CAN_USE_SHARED_DICTIONARY) || |
| !request->dictionary_getter) { |
| return network_transaction_->Start(request, std::move(callback), net_log); |
| } |
| std::optional<SharedDictionaryIsolationKey> isolation_key = |
| SharedDictionaryIsolationKey::MaybeCreate(request->network_isolation_key, |
| request->frame_origin); |
| shared_dictionary_getter_ = base::BindRepeating(request->dictionary_getter, |
| isolation_key, request->url); |
| |
| // Safe to bind unretained `this` because the callback is owned by |
| // `network_transaction_` which is owned by `this`. |
| network_transaction_->SetModifyRequestHeadersCallback(base::BindRepeating( |
| &SharedDictionaryNetworkTransaction::ModifyRequestHeaders, |
| base::Unretained(this), request->url)); |
| return network_transaction_->Start( |
| request, |
| base::BindOnce(&SharedDictionaryNetworkTransaction::OnStartCompleted, |
| base::Unretained(this), std::move(callback)), |
| net_log); |
| } |
| |
| SharedDictionaryNetworkTransaction::SharedDictionaryEncodingType |
| SharedDictionaryNetworkTransaction::ParseSharedDictionaryEncodingType( |
| const HttpResponseHeaders& headers) { |
| std::optional<std::string> content_encoding = |
| headers.GetNormalizedHeader("Content-Encoding"); |
| if (!content_encoding) { |
| return SharedDictionaryEncodingType::kNotUsed; |
| } else if (content_encoding == |
| shared_dictionary::kSharedBrotliContentEncodingName) { |
| return SharedDictionaryEncodingType::kSharedBrotli; |
| } else if (enable_shared_zstd_ && |
| content_encoding == |
| shared_dictionary::kSharedZstdContentEncodingName) { |
| return SharedDictionaryEncodingType::kSharedZstd; |
| } |
| return SharedDictionaryEncodingType::kNotUsed; |
| } |
| |
| void SharedDictionaryNetworkTransaction::OnStartCompleted( |
| CompletionOnceCallback callback, |
| int result) { |
| if (shared_dictionary_) { |
| base::UmaHistogramSparse( |
| base::StrCat({"Net.SharedDictionaryTransaction.NetResultWithDict.", |
| cert_is_issued_by_known_root_ |
| ? "KnownRootCert" |
| : "UnknownRootCertOrNoCert"}), |
| -result); |
| } |
| |
| if (result != OK || !shared_dictionary_) { |
| std::move(callback).Run(result); |
| return; |
| } |
| |
| shared_dictionary_encoding_type_ = ParseSharedDictionaryEncodingType( |
| *network_transaction_->GetResponseInfo()->headers); |
| if (shared_dictionary_encoding_type_ == |
| SharedDictionaryEncodingType::kNotUsed) { |
| std::move(callback).Run(result); |
| return; |
| } |
| |
| shared_dictionary_used_response_info_ = std::make_unique<HttpResponseInfo>( |
| *network_transaction_->GetResponseInfo()); |
| shared_dictionary_used_response_info_->did_use_shared_dictionary = true; |
| std::move(callback).Run(result); |
| } |
| |
| void SharedDictionaryNetworkTransaction::ModifyRequestHeaders( |
| const GURL& request_url, |
| HttpRequestHeaders* request_headers) { |
| // `shared_dictionary_` may have been already set if this transaction was |
| // restarted |
| if (!shared_dictionary_) { |
| shared_dictionary_ = shared_dictionary_getter_.Run(); |
| } |
| if (!shared_dictionary_) { |
| return; |
| } |
| |
| if (base::FeatureList::IsEnabled( |
| features::kCompressionDictionaryTransportRequireKnownRootCert) && |
| !cert_is_issued_by_known_root_ && !IsLocalhost(request_url)) { |
| shared_dictionary_.reset(); |
| return; |
| } |
| |
| // `is_shared_dictionary_read_allowed_callback_` triggers a notification of |
| // the shared dictionary usage to the browser process. So we need to call |
| // `is_shared_dictionary_read_allowed_callback_` after checking the result |
| // of `GetDictionarySync()`. |
| CHECK(is_shared_dictionary_read_allowed_callback_); |
| if (!is_shared_dictionary_read_allowed_callback_.Run()) { |
| shared_dictionary_.reset(); |
| return; |
| } |
| dictionary_hash_base64_ = |
| base::StrCat({":", base::Base64Encode(shared_dictionary_->hash()), ":"}); |
| request_headers->SetHeader(shared_dictionary::kAvailableDictionaryHeaderName, |
| dictionary_hash_base64_); |
| if (enable_shared_zstd_) { |
| AddAcceptEncoding( |
| request_headers, |
| base::StrCat({shared_dictionary::kSharedBrotliContentEncodingName, ", ", |
| shared_dictionary::kSharedZstdContentEncodingName})); |
| } else { |
| AddAcceptEncoding(request_headers, |
| shared_dictionary::kSharedBrotliContentEncodingName); |
| } |
| |
| if (!shared_dictionary_->id().empty()) { |
| std::optional<std::string> serialized_id = |
| structured_headers::SerializeItem(shared_dictionary_->id()); |
| if (serialized_id) { |
| request_headers->SetHeader("Dictionary-ID", *serialized_id); |
| } |
| } |
| |
| if (dictionary_status_ == DictionaryStatus::kNoDictionary) { |
| dictionary_status_ = DictionaryStatus::kReading; |
| auto split_callback = base::SplitOnceCallback(base::BindOnce( |
| [](base::WeakPtr<SharedDictionaryNetworkTransaction> self, |
| base::Time read_start_time, int result) { |
| if (!self) { |
| bool succeeded = result == OK; |
| base::UmaHistogramTimes( |
| base::StrCat({"Net.SharedDictionaryTransaction." |
| "AbortedWhileReadingDictionary.", |
| succeeded ? "Success" : "Failure"}), |
| base::Time::Now() - read_start_time); |
| return; |
| } |
| self->OnReadSharedDictionary(read_start_time, result); |
| }, |
| weak_factory_.GetWeakPtr(), /*read_start_time=*/base::Time::Now())); |
| |
| int read_result = |
| shared_dictionary_->ReadAll(std::move(split_callback.first)); |
| if (read_result != ERR_IO_PENDING) { |
| std::move(split_callback.second).Run(read_result); |
| } |
| } |
| } |
| |
| void SharedDictionaryNetworkTransaction::OnReadSharedDictionary( |
| base::Time read_start_time, |
| int result) { |
| bool succeeded = result == OK; |
| base::UmaHistogramTimes( |
| base::StrCat({"Net.SharedDictionaryTransaction.DictionaryReadLatency.", |
| succeeded ? "Success" : "Failure"}), |
| base::Time::Now() - read_start_time); |
| if (!succeeded) { |
| dictionary_status_ = DictionaryStatus::kFailed; |
| } else { |
| dictionary_status_ = DictionaryStatus::kFinished; |
| CHECK(shared_dictionary_->data()); |
| } |
| if (pending_read_task_) { |
| auto task = std::move(pending_read_task_); |
| auto split_callback = base::SplitOnceCallback(std::move(task->callback)); |
| int ret = |
| Read(task->buf.get(), task->buf_len, std::move(split_callback.first)); |
| if (ret != ERR_IO_PENDING) { |
| std::move(split_callback.second).Run(ret); |
| } |
| } |
| } |
| |
| int SharedDictionaryNetworkTransaction::RestartIgnoringLastError( |
| CompletionOnceCallback callback) { |
| shared_dictionary_used_response_info_.reset(); |
| return network_transaction_->RestartIgnoringLastError( |
| base::BindOnce(&SharedDictionaryNetworkTransaction::OnStartCompleted, |
| base::Unretained(this), std::move(callback))); |
| } |
| |
| int SharedDictionaryNetworkTransaction::RestartWithCertificate( |
| scoped_refptr<X509Certificate> client_cert, |
| scoped_refptr<SSLPrivateKey> client_private_key, |
| CompletionOnceCallback callback) { |
| shared_dictionary_used_response_info_.reset(); |
| return network_transaction_->RestartWithCertificate( |
| std::move(client_cert), std::move(client_private_key), |
| base::BindOnce(&SharedDictionaryNetworkTransaction::OnStartCompleted, |
| base::Unretained(this), std::move(callback))); |
| } |
| |
| int SharedDictionaryNetworkTransaction::RestartWithAuth( |
| const AuthCredentials& credentials, |
| CompletionOnceCallback callback) { |
| shared_dictionary_used_response_info_.reset(); |
| return network_transaction_->RestartWithAuth( |
| credentials, |
| base::BindOnce(&SharedDictionaryNetworkTransaction::OnStartCompleted, |
| base::Unretained(this), std::move(callback))); |
| } |
| |
| bool SharedDictionaryNetworkTransaction::IsReadyToRestartForAuth() { |
| return network_transaction_->IsReadyToRestartForAuth(); |
| } |
| |
| int SharedDictionaryNetworkTransaction::Read(IOBuffer* buf, |
| int buf_len, |
| CompletionOnceCallback callback) { |
| if (!shared_dictionary_used_response_info_) { |
| return network_transaction_->Read(buf, buf_len, std::move(callback)); |
| } |
| |
| switch (dictionary_status_) { |
| case DictionaryStatus::kNoDictionary: |
| NOTREACHED(); |
| case DictionaryStatus::kReading: |
| CHECK(!pending_read_task_); |
| pending_read_task_ = |
| std::make_unique<PendingReadTask>(buf, buf_len, std::move(callback)); |
| return ERR_IO_PENDING; |
| case DictionaryStatus::kFinished: |
| if (!shared_compression_stream_) { |
| // Wrap the source `network_transaction_` with a |
| // SharedDictionaryHeaderCheckerSourceStream to check the header |
| // of Dictionary-Compressed stream. |
| std::unique_ptr<SourceStream> header_checker_source_stream = |
| std::make_unique<SharedDictionaryHeaderCheckerSourceStream>( |
| std::make_unique<ProxyingSourceStream>( |
| network_transaction_.get()), |
| shared_dictionary_encoding_type_ == |
| SharedDictionaryEncodingType::kSharedBrotli |
| ? SharedDictionaryHeaderCheckerSourceStream::Type:: |
| kDictionaryCompressedBrotli |
| : SharedDictionaryHeaderCheckerSourceStream::Type:: |
| kDictionaryCompressedZstd, |
| shared_dictionary_->hash()); |
| if (shared_dictionary_encoding_type_ == |
| SharedDictionaryEncodingType::kSharedBrotli) { |
| SCOPED_UMA_HISTOGRAM_TIMER_MICROS( |
| "Network.SharedDictionary." |
| "CreateBrotliSourceStreamWithDictionary"); |
| shared_compression_stream_ = CreateBrotliSourceStreamWithDictionary( |
| std::move(header_checker_source_stream), |
| shared_dictionary_->data(), shared_dictionary_->size()); |
| } else if (shared_dictionary_encoding_type_ == |
| SharedDictionaryEncodingType::kSharedZstd) { |
| SCOPED_UMA_HISTOGRAM_TIMER_MICROS( |
| "Network.SharedDictionary.CreateZstdSourceStreamWithDictionary"); |
| shared_compression_stream_ = CreateZstdSourceStreamWithDictionary( |
| std::move(header_checker_source_stream), |
| shared_dictionary_->data(), shared_dictionary_->size()); |
| } |
| |
| UMA_HISTOGRAM_ENUMERATION("Network.SharedDictionary.EncodingType", |
| shared_dictionary_encoding_type_); |
| } |
| // When NET_DISABLE_BROTLI or NET_DISABLE_ZSTD is set, |
| // `shared_compression_stream_` can be null. |
| if (!shared_compression_stream_) { |
| return ERR_CONTENT_DECODING_FAILED; |
| } |
| return shared_compression_stream_->Read(buf, buf_len, |
| std::move(callback)); |
| case DictionaryStatus::kFailed: |
| return ERR_DICTIONARY_LOAD_FAILED; |
| } |
| } |
| |
| void SharedDictionaryNetworkTransaction::StopCaching() { |
| network_transaction_->StopCaching(); |
| } |
| |
| int64_t SharedDictionaryNetworkTransaction::GetTotalReceivedBytes() const { |
| return network_transaction_->GetTotalReceivedBytes(); |
| } |
| |
| int64_t SharedDictionaryNetworkTransaction::GetTotalSentBytes() const { |
| return network_transaction_->GetTotalSentBytes(); |
| } |
| |
| int64_t SharedDictionaryNetworkTransaction::GetReceivedBodyBytes() const { |
| return network_transaction_->GetReceivedBodyBytes(); |
| } |
| |
| void SharedDictionaryNetworkTransaction::DoneReading() { |
| network_transaction_->DoneReading(); |
| } |
| |
| const HttpResponseInfo* SharedDictionaryNetworkTransaction::GetResponseInfo() |
| const { |
| if (shared_dictionary_used_response_info_) { |
| return shared_dictionary_used_response_info_.get(); |
| } |
| return network_transaction_->GetResponseInfo(); |
| } |
| |
| LoadState SharedDictionaryNetworkTransaction::GetLoadState() const { |
| return network_transaction_->GetLoadState(); |
| } |
| |
| bool SharedDictionaryNetworkTransaction::GetLoadTimingInfo( |
| LoadTimingInfo* load_timing_info) const { |
| return network_transaction_->GetLoadTimingInfo(load_timing_info); |
| } |
| |
| void SharedDictionaryNetworkTransaction::PopulateLoadTimingInternalInfo( |
| LoadTimingInternalInfo* load_timing_internal_info) const { |
| network_transaction_->PopulateLoadTimingInternalInfo( |
| load_timing_internal_info); |
| } |
| |
| bool SharedDictionaryNetworkTransaction::GetRemoteEndpoint( |
| IPEndPoint* endpoint) const { |
| return network_transaction_->GetRemoteEndpoint(endpoint); |
| } |
| |
| void SharedDictionaryNetworkTransaction::PopulateNetErrorDetails( |
| NetErrorDetails* details) const { |
| return network_transaction_->PopulateNetErrorDetails(details); |
| } |
| |
| void SharedDictionaryNetworkTransaction::SetPriority(RequestPriority priority) { |
| network_transaction_->SetPriority(priority); |
| } |
| |
| void SharedDictionaryNetworkTransaction:: |
| SetWebSocketHandshakeStreamCreateHelper( |
| WebSocketHandshakeStreamBase::CreateHelper* create_helper) { |
| network_transaction_->SetWebSocketHandshakeStreamCreateHelper(create_helper); |
| } |
| |
| void SharedDictionaryNetworkTransaction::SetRequestHeadersCallback( |
| RequestHeadersCallback callback) { |
| network_transaction_->SetRequestHeadersCallback(std::move(callback)); |
| } |
| |
| void SharedDictionaryNetworkTransaction::SetResponseHeadersCallback( |
| ResponseHeadersCallback callback) { |
| network_transaction_->SetResponseHeadersCallback(std::move(callback)); |
| } |
| |
| void SharedDictionaryNetworkTransaction::SetEarlyResponseHeadersCallback( |
| ResponseHeadersCallback callback) { |
| network_transaction_->SetEarlyResponseHeadersCallback(std::move(callback)); |
| } |
| |
| void SharedDictionaryNetworkTransaction::SetConnectedCallback( |
| const ConnectedCallback& callback) { |
| connected_callback_ = callback; |
| } |
| |
| void SharedDictionaryNetworkTransaction::SetModifyRequestHeadersCallback( |
| base::RepeatingCallback<void(HttpRequestHeaders*)> callback) { |
| // This method should not be called for this class. |
| NOTREACHED(); |
| } |
| |
| void SharedDictionaryNetworkTransaction:: |
| SetIsSharedDictionaryReadAllowedCallback( |
| base::RepeatingCallback<bool()> callback) { |
| is_shared_dictionary_read_allowed_callback_ = std::move(callback); |
| } |
| |
| ConnectionAttempts SharedDictionaryNetworkTransaction::GetConnectionAttempts() |
| const { |
| return network_transaction_->GetConnectionAttempts(); |
| } |
| |
| void SharedDictionaryNetworkTransaction::CloseConnectionOnDestruction() { |
| network_transaction_->CloseConnectionOnDestruction(); |
| } |
| |
| bool SharedDictionaryNetworkTransaction::IsMdlMatchForMetrics() const { |
| return network_transaction_->IsMdlMatchForMetrics(); |
| } |
| |
| int SharedDictionaryNetworkTransaction::OnConnected( |
| const TransportInfo& info, |
| CompletionOnceCallback callback) { |
| cert_is_issued_by_known_root_ = info.cert_is_issued_by_known_root; |
| negotiated_protocol_ = info.negotiated_protocol; |
| |
| if (connected_callback_) { |
| return connected_callback_.Run(info, std::move(callback)); |
| } |
| return OK; |
| } |
| |
| } // namespace net |