blob: fd69063dde5f8fbe5f3942c5761b4b37305ecc80 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/legion/websocket_client.h"
#include <limits>
#include <utility>
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/json/json_reader.h"
#include "base/json/json_writer.h"
#include "base/task/sequenced_task_runner.h"
#include "base/types/expected.h"
#include "components/legion/session_deserializer.h"
#include "net/http/http_request_headers.h"
#include "net/storage_access_api/status.h"
#include "net/traffic_annotation/network_traffic_annotation.h"
#include "services/network/public/mojom/network_context.mojom.h"
#include "third_party/oak/chromium/proto/session/session.pb.h"
#include "third_party/oak/chromium/proto/session/session.to_value.h"
namespace legion {
namespace {
constexpr size_t kMaxIncomingMessageSize = 1 << 20;
constexpr net::NetworkTrafficAnnotationTag kTrafficAnnotation =
net::DefineNetworkTrafficAnnotation("legion_client", R"(
semantics {
sender: "Legion Client"
description:
"This traffic creates an encrypted session with the "
"legion service and carries the request and response over that "
"session. "
"The feature is under development and behind a feature flag."
trigger:
"A feature that uses the legion component is triggered. "
"The feature determines which data to send."
user_data {
type: PROFILE_DATA
}
data: "This contains an encrypted request."
internal {
contacts {
email: "dullweber@chromium.org"
}
}
destination: GOOGLE_OWNED_SERVICE
last_reviewed: "2025-09-12"
}
policy {
cookies_allowed: NO
setting:
"Still in development. Setting may be added later."
policy_exception_justification:
"Still in development. Policy may be added later."
})");
} // namespace
WebSocketClient::WebSocketClient(const GURL& service_url,
NetworkContextFactory network_context_factory)
: service_url_(service_url),
network_context_factory_(std::move(network_context_factory)),
readable_watcher_(FROM_HERE, mojo::SimpleWatcher::ArmingPolicy::MANUAL) {}
WebSocketClient::~WebSocketClient() = default;
void WebSocketClient::Send(const oak::session::v1::SessionRequest& request,
ResponseCallback callback) {
DCHECK(!response_callback_);
response_callback_ = std::move(callback);
base::Value value = oak::session::v1::Serialize(request);
std::string json;
base::JSONWriter::Write(value, &json);
Send(Request(json.begin(), json.end()));
}
void WebSocketClient::Send(Request request) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (state_ == State::kDisconnected ||
request.size() > std::numeric_limits<uint32_t>::max()) {
ClosePipe(TransportError::kError);
return;
}
if (state_ == State::kInitialized) {
Connect();
}
if (state_ != State::kOpen) {
pending_write_data_.push(std::move(request));
return;
}
InternalWrite(request);
}
void WebSocketClient::OnResponse(
base::expected<std::vector<uint8_t>, TransportError> response) {
if (!response_callback_) {
return;
}
if (response.has_value()) {
auto value = base::JSONReader::Read(
std::string(response->begin(), response->end()), base::JSON_PARSE_RFC);
if (value) {
oak::session::v1::SessionResponse session_response;
if (DeserializeSessionResponse(*value, &session_response)) {
std::move(response_callback_)
.Run(base::ok(std::move(session_response)));
return;
}
}
std::move(response_callback_)
.Run(base::unexpected(TransportError::kDeserializationError));
return;
}
std::move(response_callback_).Run(base::unexpected(response.error()));
}
void WebSocketClient::Connect() {
// A disconnect handler is used so that the request can be completed in the
// event of an unexpected disconnection from the network service.
auto handshake_remote = handshake_receiver_.BindNewPipeAndPassRemote();
// base::Unretained(this) is safe because client_receiver_ is owned by
// |this|.
handshake_receiver_.set_disconnect_handler(base::BindOnce(
&WebSocketClient::OnMojoPipeDisconnect, base::Unretained(this)));
state_ = State::kConnecting;
std::vector<std::string> requested_protocols;
std::vector<network::mojom::HttpHeaderPtr> additional_headers;
network_context_factory_.Run()->CreateWebSocket(
service_url_, requested_protocols, net::SiteForCookies(),
net::StorageAccessApiStatus::kNone,
net::IsolationInfo::CreateForInternalRequest(
url::Origin::Create(service_url_)),
std::move(additional_headers), network::mojom::kBrowserProcessId,
url::Origin::Create(service_url_),
network::mojom::ClientSecurityState::New(),
network::mojom::kWebSocketOptionBlockAllCookies,
net::MutableNetworkTrafficAnnotationTag(kTrafficAnnotation),
std::move(handshake_remote),
/*url_loader_network_observer=*/mojo::NullRemote(),
/*auth_handler=*/mojo::NullRemote(),
/*header_client=*/mojo::NullRemote(),
/*throttling_profile_id=*/std::nullopt);
}
void WebSocketClient::InternalWrite(base::span<const uint8_t> data) {
CHECK(state_ == State::kOpen);
websocket_->SendMessage(network::mojom::WebSocketMessageType::TEXT,
data.size());
MojoResult result = writable_->WriteAllData(data);
if (result != MOJO_RESULT_OK) {
LOG(ERROR) << "Failed to write to WebSocket.";
ClosePipe(TransportError::kError);
}
}
void WebSocketClient::OnOpeningHandshakeStarted(
network::mojom::WebSocketHandshakeRequestPtr request) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
void WebSocketClient::OnFailure(const std::string& message,
int net_error,
int response_code) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
LOG(ERROR) << "Legion service connection failed " << message << ", "
<< net_error << ", " << response_code;
ClosePipe(TransportError::kError);
}
void WebSocketClient::OnConnectionEstablished(
mojo::PendingRemote<network::mojom::WebSocket> socket,
mojo::PendingReceiver<network::mojom::WebSocketClient> client_receiver,
network::mojom::WebSocketHandshakeResponsePtr response,
mojo::ScopedDataPipeConsumerHandle readable,
mojo::ScopedDataPipeProducerHandle writable) {
CHECK(!websocket_.is_bound());
CHECK(state_ == State::kConnecting);
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
websocket_.Bind(std::move(socket));
readable_ = std::move(readable);
// base::Unretained(this) is safe because readable_watcher_ is owned by
// |this|.
CHECK_EQ(readable_watcher_.Watch(
readable_.get(), MOJO_HANDLE_SIGNAL_READABLE,
MOJO_TRIGGER_CONDITION_SIGNALS_SATISFIED,
base::BindRepeating(&WebSocketClient::ReadFromDataPipe,
base::Unretained(this))),
MOJO_RESULT_OK);
writable_ = std::move(writable);
client_receiver_.Bind(std::move(client_receiver));
// |handshake_receiver_| will disconnect soon. In order to catch network
// process crashes, we switch to watching |client_receiver_|.
handshake_receiver_.set_disconnect_handler(base::DoNothing());
// base::Unretained(this) is safe because client_receiver_ is owned by
// |this|.
client_receiver_.set_disconnect_handler(base::BindOnce(
&WebSocketClient::OnMojoPipeDisconnect, base::Unretained(this)));
websocket_->StartReceiving();
state_ = State::kOpen;
while (!pending_write_data_.empty()) {
InternalWrite(pending_write_data_.front());
// Writing might fail which will close the socket.
if (state_ != State::kOpen) {
return;
}
pending_write_data_.pop();
}
}
void WebSocketClient::OnDataFrame(bool finish,
network::mojom::WebSocketMessageType type,
uint64_t data_len) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
CHECK_EQ(state_, State::kOpen);
CHECK_EQ(pending_read_data_index_, pending_read_data_.size());
CHECK(!pending_read_finished_);
if (data_len == 0) {
if (finish) {
ProcessCompletedResponse();
}
return;
}
const size_t old_size = pending_read_data_index_;
const size_t new_size = old_size + data_len;
if ((type != network::mojom::WebSocketMessageType::BINARY &&
type != network::mojom::WebSocketMessageType::CONTINUATION) ||
data_len > std::numeric_limits<uint32_t>::max() || new_size < old_size ||
new_size > kMaxIncomingMessageSize) {
LOG(ERROR) << "Invalid WebSocket frame (type: " << static_cast<int>(type)
<< ", len: " << data_len << ")";
ClosePipe(TransportError::kError);
return;
}
pending_read_data_.resize(new_size);
pending_read_finished_ = finish;
client_receiver_.Pause();
ReadFromDataPipe(MOJO_RESULT_OK, mojo::HandleSignalsState());
}
void WebSocketClient::OnDropChannel(bool was_clean,
uint16_t code,
const std::string& reason) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
CHECK(state_ == State::kOpen || state_ == State::kConnecting);
LOG(ERROR) << "OnDropChannel: " << code << " - " << reason;
ClosePipe(TransportError::kSocketClosed);
}
void WebSocketClient::OnClosingHandshake() {}
void WebSocketClient::ReadFromDataPipe(MojoResult,
const mojo::HandleSignalsState&) {
CHECK_LT(pending_read_data_index_, pending_read_data_.size());
size_t actually_read_bytes = 0;
const MojoResult result = readable_->ReadData(
MOJO_READ_DATA_FLAG_NONE,
base::span(pending_read_data_).subspan(pending_read_data_index_),
actually_read_bytes);
if (result == MOJO_RESULT_OK) {
pending_read_data_index_ += actually_read_bytes;
DCHECK_LE(pending_read_data_index_, pending_read_data_.size());
if (pending_read_data_index_ < pending_read_data_.size()) {
readable_watcher_.ArmOrNotify();
} else {
client_receiver_.Resume();
if (pending_read_finished_) {
ProcessCompletedResponse();
}
}
} else if (result == MOJO_RESULT_SHOULD_WAIT) {
readable_watcher_.ArmOrNotify();
} else {
LOG(ERROR) << "Reading WebSocket frame failed: "
<< static_cast<int>(result);
ClosePipe(TransportError::kError);
}
}
void WebSocketClient::ProcessCompletedResponse() {
std::vector<uint8_t> pending_read_data;
pending_read_data.swap(pending_read_data_);
pending_read_data_index_ = 0;
pending_read_finished_ = false;
// Call OnResponse asynchronously since this object may be destroyed during
// the callback.
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(&WebSocketClient::OnResponse,
weak_ptr_factory_.GetWeakPtr(),
base::ok(std::move(pending_read_data))));
}
void WebSocketClient::ClosePipe(TransportError status) {
if (state_ == State::kDisconnected) {
return;
}
state_ = State::kDisconnected;
client_receiver_.reset();
pending_write_data_ = {};
pending_read_data_index_ = 0;
pending_read_finished_ = false;
pending_read_data_.clear();
// Call OnResponse asynchronously since this object may be destroyed during
// the callback.
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE,
base::BindOnce(&WebSocketClient::OnResponse,
weak_ptr_factory_.GetWeakPtr(), base::unexpected(status)));
}
void WebSocketClient::OnMojoPipeDisconnect() {
ClosePipe(TransportError::kSocketClosed);
}
} // namespace legion