| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/browser/model_execution/model_execution_session.h" |
| |
| #include <optional> |
| |
| #include "base/functional/bind.h" |
| #include "components/optimization_guide/core/model_execution/optimization_guide_model_execution_error.h" |
| #include "components/optimization_guide/core/optimization_guide_enums.h" |
| #include "components/optimization_guide/core/optimization_guide_model_executor.h" |
| #include "components/optimization_guide/core/optimization_guide_util.h" |
| #include "components/optimization_guide/proto/common_types.pb.h" |
| #include "components/optimization_guide/proto/string_value.pb.h" |
| #include "mojo/public/cpp/bindings/remote_set.h" |
| #include "third_party/blink/public/mojom/model_execution/model_session.mojom-shared.h" |
| |
| using ModelExecutionError = optimization_guide:: |
| OptimizationGuideModelExecutionError::ModelExecutionError; |
| |
| ModelExecutionSession::ModelExecutionSession( |
| std::unique_ptr<optimization_guide::OptimizationGuideModelExecutor::Session> |
| session) |
| : session_(std::move(session)) {} |
| |
| ModelExecutionSession::~ModelExecutionSession() = default; |
| |
| blink::mojom::ModelStreamingResponseStatus ConvertModelExecutionError( |
| ModelExecutionError error) { |
| switch (error) { |
| case ModelExecutionError::kUnknown: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorUnknown; |
| case ModelExecutionError::kInvalidRequest: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorInvalidRequest; |
| case ModelExecutionError::kRequestThrottled: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorRequestThrottled; |
| case ModelExecutionError::kPermissionDenied: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorPermissionDenied; |
| case ModelExecutionError::kGenericFailure: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure; |
| case ModelExecutionError::kRetryableError: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorRetryableError; |
| case ModelExecutionError::kNonRetryableError: |
| return blink::mojom::ModelStreamingResponseStatus:: |
| kErrorNonRetryableError; |
| case ModelExecutionError::kUnsupportedLanguage: |
| return blink::mojom::ModelStreamingResponseStatus:: |
| kErrorUnsupportedLanguage; |
| case ModelExecutionError::kFiltered: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorFiltered; |
| case ModelExecutionError::kDisabled: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorDisabled; |
| case ModelExecutionError::kCancelled: |
| return blink::mojom::ModelStreamingResponseStatus::kErrorCancelled; |
| } |
| } |
| |
| void ModelExecutionSession::ModelExecutionCallback( |
| mojo::RemoteSetElementId responder_id, |
| optimization_guide::OptimizationGuideModelStreamingExecutionResult result) { |
| blink::mojom::ModelStreamingResponder* responder = |
| responder_set_.Get(responder_id); |
| if (!responder) { |
| return; |
| } |
| |
| if (!result.response.has_value()) { |
| responder->OnResponse( |
| ConvertModelExecutionError(result.response.error().error()), |
| std::nullopt); |
| return; |
| } |
| |
| auto response = optimization_guide::ParsedAnyMetadata< |
| optimization_guide::proto::StringValue>(result.response->response); |
| if (response->has_value()) { |
| responder->OnResponse(blink::mojom::ModelStreamingResponseStatus::kOngoing, |
| response->value()); |
| } |
| if (result.response->is_complete) { |
| responder->OnResponse(blink::mojom::ModelStreamingResponseStatus::kComplete, |
| std::nullopt); |
| } |
| } |
| |
| void ModelExecutionSession::Execute( |
| const std::string& input, |
| mojo::PendingRemote<blink::mojom::ModelStreamingResponder> responder) { |
| mojo::RemoteSetElementId responder_id = |
| responder_set_.Add(std::move(responder)); |
| optimization_guide::proto::StringValue request; |
| request.set_value(input); |
| session_->ExecuteModel( |
| request, |
| base::BindRepeating(&ModelExecutionSession::ModelExecutionCallback, |
| weak_ptr_factory_.GetWeakPtr(), responder_id)); |
| } |