blob: cdf0c0a85c89d42145345007648ebb0b544115b6 [file]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/legion/client_impl.h"
#include <utility>
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/logging.h"
#include "base/metrics/histogram_functions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_runner.h"
#include "base/time/time.h"
#include "components/legion/proto/legion.pb.h"
#include "components/legion/proto_utils/generate_content_response_utils.h"
namespace legion {
namespace {
// Adapts the result of a GenerateContent request into a plain-text result.
//
// - If `result` holds an error, `cb` is run with that same error.
// - If ConvertGenerateContentResponseToText() yields no value for the
//   response, `cb` is run with ErrorCode::kNoContent.
// - Otherwise `cb` is run with the extracted text.
void OnGenerateContentRequestCompleted(
    Client::OnTextRequestCompletedCallback cb,
    base::expected<proto::GenerateContentResponse, ErrorCode> result) {
  if (!result.has_value()) {
    std::move(cb).Run(base::unexpected(result.error()));
    return;
  }

  auto text = ConvertGenerateContentResponseToText(*result);
  if (!text.has_value()) {
    LOG(ERROR) << "GenerateContentResponse did not contain any content";
    std::move(cb).Run(base::unexpected(ErrorCode::kNoContent));
    return;
  }

  // `text` is a local that is not used again, so hand its payload to the
  // callback by move rather than copying it.
  std::move(cb).Run(std::move(text).value());
}
// Decodes the binary response of a GenerateContent request and reports the
// outcome through `cb`.
//
// - If `result` holds an error, `cb` is run with that same error.
// - If the bytes cannot be parsed as a proto::LegionResponse, `cb` is run with
//   ErrorCode::kResponseParseError.
// - If the parsed message has no generate_content_response, `cb` is run with
//   ErrorCode::kNoResponse.
// - Otherwise `cb` is run with the contained GenerateContentResponse.
void OnRequestSent(
    Client::OnGenerateContentRequestCompletedCallback cb,
    base::expected<ClientImpl::BinaryEncodedProtoResponse, ErrorCode> result) {
  if (!result.has_value()) {
    std::move(cb).Run(base::unexpected(result.error()));
    return;
  }

  proto::LegionResponse legion_response;
  if (!legion_response.ParseFromArray(result->data(), result->size())) {
    LOG(ERROR) << "Failed to parse LegionResponse";
    std::move(cb).Run(base::unexpected(ErrorCode::kResponseParseError));
    return;
  }

  if (!legion_response.has_generate_content_response()) {
    LOG(ERROR) << "LegionResponse did not contain a "
                  "generate_content_response";
    std::move(cb).Run(base::unexpected(ErrorCode::kNoResponse));
    return;
  }

  // `legion_response` is a local that is discarded right after this call, so
  // move the sub-message out of it instead of deep-copying it through the
  // const accessor.
  std::move(cb).Run(
      std::move(*legion_response.mutable_generate_content_response()));
}
} // namespace
// Constructs a client that stores `channel_factory`. The factory is run by
// GetOrCreateSecureChannel() whenever no secure channel currently exists.
ClientImpl::ClientImpl(SecureChannelFactory channel_factory)
: secure_channel_factory_(std::move(channel_factory)) {}
// Defaulted destructor: no teardown beyond destroying the members.
ClientImpl::~ClientImpl() = default;
// Asks the secure channel (created on demand) to establish itself. The outcome
// is delivered to OnSessionEstablished(), which is bound through a weak
// pointer to this object and carries `callback` along with it.
void ClientImpl::EstablishSession(
    OnEstablishSessionCompletedCallback callback) {
  SecureChannel* const channel = GetOrCreateSecureChannel();
  auto on_established =
      base::BindOnce(&ClientImpl::OnSessionEstablished,
                     weak_factory_.GetWeakPtr(), std::move(callback));
  channel->EstablishChannel(std::move(on_established));
}
// Returns the current secure channel. When none exists, one is obtained from
// the stored factory and its response callback is pointed at
// OnResponseReceived() before it is returned.
SecureChannel* ClientImpl::GetOrCreateSecureChannel() {
  if (secure_channel_) {
    return secure_channel_.get();
  }

  secure_channel_ = secure_channel_factory_.Run();
  auto on_response = base::BindRepeating(&ClientImpl::OnResponseReceived,
                                         base::Unretained(this));
  secure_channel_->SetResponseCallback(std::move(on_response));
  return secure_channel_.get();
}
// Writes `request` to the secure channel and tracks it under `request_id`.
//
// `callback` is wrapped in OnRequestCompleted() together with the start time.
// If the write succeeds, the wrapped callback is stored in
// `pending_requests_` and OnRequestTimeout() is scheduled after `timeout`.
// If the write fails, the wrapped callback is run immediately with
// ErrorCode::kError.
void ClientImpl::SendRequest(int32_t request_id,
                             BinaryEncodedProtoRequest request,
                             OnRequestCompletedCallback callback,
                             base::TimeDelta timeout) {
  DVLOG(1) << "SendRequest started.";

  // Request size in bytes; the histogram tops out at 1M.
  base::UmaHistogramCounts1M("Legion.Client.RequestSize", request.size());

  auto completion = base::BindOnce(
      &ClientImpl::OnRequestCompleted, weak_factory_.GetWeakPtr(),
      std::move(callback), base::TimeTicks::Now());

  if (!GetOrCreateSecureChannel()->Write(std::move(request))) {
    // The channel is in a permanent failure state, so fail the current
    // request.
    DVLOG(1) << "Secure channel write failed.";
    std::move(completion).Run(base::unexpected(ErrorCode::kError));
    return;
  }

  pending_requests_.emplace(request_id, std::move(completion));
  base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&ClientImpl::OnRequestTimeout, weak_factory_.GetWeakPtr(),
                     request_id),
      timeout);
}
// Wraps `text` in a GenerateContentRequest holding a single "user" content
// entry with one text part, then sends it via SendGenerateContentRequest().
// For FEATURE_NAME_DEMO_GEMINI_GENERATE_CONTENT the model is set to
// "dev_v3xs". `callback` receives the response converted to text by
// OnGenerateContentRequestCompleted().
void ClientImpl::SendTextRequest(proto::FeatureName feature_name,
                                 const std::string& text,
                                 OnTextRequestCompletedCallback callback,
                                 const RequestOptions& options) {
  proto::GenerateContentRequest generate_request;

  const bool is_demo_feature =
      feature_name ==
      proto::FeatureName::FEATURE_NAME_DEMO_GEMINI_GENERATE_CONTENT;
  if (is_demo_feature) {
    generate_request.set_model("dev_v3xs");
  }

  auto* user_content = generate_request.add_contents();
  user_content->set_role("user");
  user_content->add_parts()->set_text(text);

  SendGenerateContentRequest(
      feature_name, generate_request,
      base::BindOnce(&OnGenerateContentRequestCompleted, std::move(callback)),
      options);
}
// Wraps `request` in a proto::LegionRequest tagged with `feature_name` and a
// freshly allocated request id, serializes it and hands it to SendRequest()
// with `options.timeout`.
//
// `callback` receives the parsed GenerateContentResponse, or an error. If the
// request cannot be serialized, `callback` is run immediately with
// ErrorCode::kError and nothing is written to the channel.
void ClientImpl::SendGenerateContentRequest(
    proto::FeatureName feature_name,
    const proto::GenerateContentRequest& request,
    OnGenerateContentRequestCompletedCallback callback,
    const RequestOptions& options) {
  const int32_t request_id = next_request_id_++;

  proto::LegionRequest request_proto;
  request_proto.set_feature_name(feature_name);
  request_proto.set_request_id(request_id);
  *request_proto.mutable_generate_content_request() = request;

  base::UmaHistogramSparse("Legion.Client.FeatureName",
                           static_cast<int>(feature_name));

  std::string serialized_request;
  if (!request_proto.SerializeToString(&serialized_request)) {
    // Do not send a partially serialized or empty payload; report the failure
    // the same way SendRequest() reports a failed write.
    LOG(ERROR) << "Failed to serialize LegionRequest";
    std::move(callback).Run(base::unexpected(ErrorCode::kError));
    return;
  }

  BinaryEncodedProtoRequest binary_encoded_proto_request(
      serialized_request.begin(), serialized_request.end());

  // The callback for when the response is received.
  auto response_parsing_callback =
      base::BindOnce(&OnRequestSent, std::move(callback));
  SendRequest(request_id, std::move(binary_encoded_proto_request),
              std::move(response_parsing_callback), options.timeout);
}
// Runs every pending request callback with `error_code` and leaves
// `pending_requests_` empty.
void ClientImpl::FailAllPendingRequests(ErrorCode error_code) {
  // Detach the pending set before running any callback. std::exchange leaves
  // `pending_requests_` as a definitely-empty container (a plain std::move
  // would leave it in an unspecified state), so a callback that re-enters
  // SendRequest() inserts into a well-defined map and is not affected by this
  // loop.
  auto pending_requests = std::exchange(pending_requests_, {});
  for (auto& entry : pending_requests) {
    std::move(entry.second).Run(base::unexpected(error_code));
  }
}
// Completion handler bound in EstablishSession(): forwards the establishment
// `result` to the caller-supplied `callback` unchanged.
void ClientImpl::OnSessionEstablished(
OnEstablishSessionCompletedCallback callback,
base::expected<void, ErrorCode> result) {
std::move(callback).Run(std::move(result));
}
// Delayed task scheduled by SendRequest(). If `request_id` is still pending,
// it is recorded in `timed_out_requests_`, removed from `pending_requests_`,
// and its callback is run with ErrorCode::kTimeout. Otherwise this does
// nothing.
void ClientImpl::OnRequestTimeout(int32_t request_id) {
  auto pending_it = pending_requests_.find(request_id);
  if (pending_it == pending_requests_.end()) {
    // The request is no longer pending, so there is nothing to time out.
    return;
  }

  DLOG(ERROR) << "Request timed out: " << request_id;
  timed_out_requests_.insert(request_id);

  // Take the callback out of the map before running it.
  auto on_completed = std::move(pending_it->second);
  pending_requests_.erase(pending_it);
  std::move(on_completed).Run(base::unexpected(ErrorCode::kTimeout));
}
// Response callback installed on the secure channel by
// GetOrCreateSecureChannel().
//
// - If `result` holds an error, the channel is dropped and every pending
//   request is failed with that error.
// - If the payload cannot be parsed as a proto::LegionResponse, every pending
//   request is failed with ErrorCode::kResponseParseError.
// - If the response's request id is not pending, it is logged and ignored.
// - Otherwise the matching pending callback is run with `result`.
void ClientImpl::OnResponseReceived(
    base::expected<BinaryEncodedProtoResponse, ErrorCode> result) {
  if (!result.has_value()) {
    // The secure channel is broken. Destroy it so that it is recreated on the
    // next request, then fail all pending requests.
    //
    // The channel is dropped *before* the callbacks run, and no member is
    // touched after them: a callback may immediately retry (which must not
    // reuse the broken channel) or may destroy this object.
    DVLOG(1) << "Secure channel read failed. Destroying channel.";
    secure_channel_.reset();
    // Late responses for timed-out requests can no longer arrive on the
    // destroyed channel, so their bookkeeping entries are dropped as well.
    timed_out_requests_.clear();
    FailAllPendingRequests(result.error());
    return;
  }

  proto::LegionResponse legion_response;
  if (!legion_response.ParseFromArray(result->data(), result->size())) {
    LOG(ERROR) << "Failed to parse LegionResponse";
    // This is a protocol error. We don't know which request this response was
    // for, so we fail all of them.
    FailAllPendingRequests(ErrorCode::kResponseParseError);
    return;
  }

  const auto request_id = legion_response.request_id();
  auto it = pending_requests_.find(request_id);
  if (it == pending_requests_.end()) {
    // This could be a response to a request that has already timed out and was
    // removed from the pending list. In this case we should just ignore it and
    // not cancel other pending requests.
    auto timed_out_it = timed_out_requests_.find(request_id);
    if (timed_out_it != timed_out_requests_.end()) {
      DLOG(ERROR) << "Received response for timed out request_id: "
                  << request_id;
      timed_out_requests_.erase(timed_out_it);
    } else {
      DLOG(ERROR) << "Received response for unknown request_id: "
                  << request_id;
    }
    return;
  }

  // Remove the entry before running the callback.
  auto callback = std::move(it->second);
  pending_requests_.erase(it);
  std::move(callback).Run(std::move(result));
}
// Wrapper bound around every request callback in SendRequest(). Records
// metrics for the finished request and then forwards `result` to `callback`.
//
// On success it records the response size and success latency. On failure it
// records the error code, plus the latency under a timeout-specific histogram
// for ErrorCode::kTimeout or the generic error histogram otherwise.
void ClientImpl::OnRequestCompleted(
    OnRequestCompletedCallback callback,
    base::TimeTicks start_time,
    base::expected<BinaryEncodedProtoResponse, ErrorCode> result) {
  const auto elapsed = base::TimeTicks::Now() - start_time;

  if (!result.has_value()) {
    const ErrorCode error = result.error();
    base::UmaHistogramEnumeration("Legion.Client.RequestErrorCode", error);
    if (error == ErrorCode::kTimeout) {
      base::UmaHistogramMediumTimes("Legion.Client.RequestLatency.Timeout",
                                    elapsed);
    } else {
      base::UmaHistogramMediumTimes("Legion.Client.RequestLatency.Error",
                                    elapsed);
    }
  } else {
    // Response size in bytes; the histogram tops out at 1M.
    base::UmaHistogramCounts1M("Legion.Client.ResponseSize.Success",
                               result->size());
    base::UmaHistogramMediumTimes("Legion.Client.RequestLatency.Success",
                                  elapsed);
  }

  std::move(callback).Run(std::move(result));
}
} // namespace legion