blob: a2754dbcd8d44e0e67f37f5e20fbd7924cb364a9 [file] [log] [blame]
// Copyright 2026 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/private_ai/connection_basic.h"
#include <utility>
#include "base/check.h"
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "components/private_ai/connection.h"
#include "components/private_ai/proto/private_ai.pb.h"
namespace private_ai {
ConnectionBasic::ConnectionBasic(
std::unique_ptr<SecureChannel::Factory> secure_channel_factory,
base::OnceCallback<void(ErrorCode)> on_disconnect)
: on_disconnect_(std::move(on_disconnect)) {
CHECK(secure_channel_factory);
CHECK(on_disconnect_);
secure_channel_ = secure_channel_factory->Create(base::BindRepeating(
&ConnectionBasic::OnResponseReceived, weak_factory_.GetWeakPtr()));
CHECK(secure_channel_);
}
ConnectionBasic::~ConnectionBasic() = default;
void ConnectionBasic::Send(proto::PrivateAiRequest request,
base::TimeDelta timeout,
OnRequestCallback callback) {
// Indicates that `secure_channel_` is already closed.
if (!on_disconnect_) {
std::move(callback).Run(base::unexpected(ErrorCode::kError));
return;
}
const int32_t request_id = next_request_id_++;
request.set_request_id(request_id);
pending_request_callbacks_.emplace(request_id, std::move(callback));
std::string serialized_request;
request.SerializeToString(&serialized_request);
if (!secure_channel_->Write(
Request(serialized_request.begin(), serialized_request.end()))) {
// The channel is in a permanent failure state.
DVLOG(1) << "Secure channel write failed.";
CallOnDisconnect(ErrorCode::kError);
}
}
void ConnectionBasic::OnDestroy(ErrorCode error) {
on_disconnect_.Reset();
auto pending_requests = std::move(pending_request_callbacks_);
for (auto& pending_request : pending_requests) {
std::move(pending_request.second).Run(base::unexpected(error));
}
weak_factory_.InvalidateWeakPtrsAndDoom();
}
void ConnectionBasic::OnResponseReceived(
base::expected<Response, ErrorCode> result) {
if (!result.has_value()) {
DVLOG(1) << "Secure channel returned an error.";
CallOnDisconnect(result.error());
return;
}
proto::PrivateAiResponse private_ai_response;
if (!private_ai_response.ParseFromArray(result->data(), result->size())) {
LOG(ERROR) << "Failed to parse PrivateAiResponse";
// This is a protocol error. We don't know which request this response was
// for, so we fail all of them.
CallOnDisconnect(ErrorCode::kResponseParseError);
return;
}
auto it = pending_request_callbacks_.find(private_ai_response.request_id());
if (it == pending_request_callbacks_.end()) {
LOG(ERROR) << "Received PrivateAiResponse for unknown request_id: "
<< private_ai_response.request_id();
return;
}
auto callback = std::move(it->second);
pending_request_callbacks_.erase(it);
std::move(callback).Run(std::move(private_ai_response));
}
void ConnectionBasic::CallOnDisconnect(ErrorCode error_code) {
// First reject pending requests so that higher layers like
// ConnectionTokenAttestation can report the correct error code.
auto pending_requests = std::move(pending_request_callbacks_);
for (auto& pending_request : pending_requests) {
std::move(pending_request.second).Run(base::unexpected(error_code));
}
if (on_disconnect_) {
std::move(on_disconnect_).Run(error_code);
}
}
} // namespace private_ai