blob: 094fc240cafad7003b035dc05b715e26ca3510d9 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/model_execution/model_execution_session.h"
#include <optional>
#include "base/functional/bind.h"
#include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "third_party/blink/public/mojom/model_execution/model_session.mojom-shared.h"
// File-wide shorthand for the error enum nested inside
// optimization_guide::OptimizationGuideModelExecutionError; used by
// ConvertModelExecutionError() below.
using ModelExecutionError = optimization_guide::
OptimizationGuideModelExecutionError::ModelExecutionError;
// Creates a session wrapper around an optimization guide executor session.
// Ownership of `session` is transferred: it is moved into `session_`, which
// Execute() later uses to run the model.
ModelExecutionSession::ModelExecutionSession(
std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session>
session)
: session_(std::move(session)) {}
// Defaulted destructor; members are released by their own destructors.
ModelExecutionSession::~ModelExecutionSession() = default;
// Translates an optimization guide execution error into the streaming
// response status reported over the blink mojom interface. Each error
// enumerator maps one-to-one onto the status enumerator of the same name with
// an `kError` prefix.
//
// The switch deliberately carries no `default:` label, so a build with
// switch-exhaustiveness warnings enabled can flag any enumerator added to
// ModelExecutionError that is not handled here.
blink::mojom::ModelStreamingResponseStatus ConvertModelExecutionError(
    ModelExecutionError error) {
  using Status = blink::mojom::ModelStreamingResponseStatus;

  switch (error) {
    case ModelExecutionError::kUnknown:
      return Status::kErrorUnknown;
    case ModelExecutionError::kInvalidRequest:
      return Status::kErrorInvalidRequest;
    case ModelExecutionError::kRequestThrottled:
      return Status::kErrorRequestThrottled;
    case ModelExecutionError::kPermissionDenied:
      return Status::kErrorPermissionDenied;
    case ModelExecutionError::kGenericFailure:
      return Status::kErrorGenericFailure;
    case ModelExecutionError::kRetryableError:
      return Status::kErrorRetryableError;
    case ModelExecutionError::kNonRetryableError:
      return Status::kErrorNonRetryableError;
    case ModelExecutionError::kUnsupportedLanguage:
      return Status::kErrorUnsupportedLanguage;
    case ModelExecutionError::kFiltered:
      return Status::kErrorFiltered;
    case ModelExecutionError::kDisabled:
      return Status::kErrorDisabled;
    case ModelExecutionError::kCancelled:
      return Status::kErrorCancelled;
  }
}
// Relays one streaming execution result to the responder that was registered
// under `responder_id` by Execute().
//
// - If the responder is no longer present in `responder_set_`, the result is
//   dropped.
// - If `result.response` holds an error, the error is converted with
//   ConvertModelExecutionError() and reported with no text; nothing further is
//   sent for this result.
// - Otherwise the response payload is parsed as a StringValue. When a value is
//   present it is sent with status kOngoing, and when the result is marked
//   complete a final kComplete status (with no text) follows.
void ModelExecutionSession::ModelExecutionCallback(
    mojo::RemoteSetElementId responder_id,
    optimization_guide::OptimizationGuideModelStreamingExecutionResult result) {
  blink::mojom::ModelStreamingResponder* responder =
      responder_set_.Get(responder_id);
  if (!responder) {
    return;
  }

  if (!result.response.has_value()) {
    responder->OnResponse(
        ConvertModelExecutionError(result.response.error().error()),
        std::nullopt);
    return;
  }

  auto response = optimization_guide::ParsedAnyMetadata<
      optimization_guide::proto::StringValue>(result.response->response);
  // `response` itself may be empty when the payload could not be parsed as a
  // StringValue, so check it before dereferencing; the inner `has_value()` is
  // the StringValue field-presence accessor. Dereferencing an empty result
  // would be undefined behaviour.
  if (response && response->has_value()) {
    responder->OnResponse(blink::mojom::ModelStreamingResponseStatus::kOngoing,
                          response->value());
  }

  // Completion is signalled independently of whether this chunk carried text.
  if (result.response->is_complete) {
    responder->OnResponse(blink::mojom::ModelStreamingResponseStatus::kComplete,
                          std::nullopt);
  }
}
void ModelExecutionSession::Execute(
const std::string& input,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder> responder) {
mojo::RemoteSetElementId responder_id =
responder_set_.Add(std::move(responder));
optimization_guide::proto::StringValue request;
request.set_value(input);
session_->ExecuteModel(
request,
base::BindRepeating(&ModelExecutionSession::ModelExecutionCallback,
weak_ptr_factory_.GetWeakPtr(), responder_id));
}