blob: ab6048b09fee89b435238a7b04dcd8ca5532de45 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/on_device_translation/translator.h"
#include <algorithm>
#include "base/feature_list.h"
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "chrome/browser/on_device_translation/pref_names.h"
#include "chrome/browser/on_device_translation/service_controller.h"
#include "chrome/browser/on_device_translation/translation_metrics.h"
#include "chrome/browser/profiles/profile.h"
#include "components/prefs/pref_service.h"
#include "components/services/on_device_translation/public/cpp/features.h"
#include "mojo/public/cpp/bindings/callback_helpers.h"
#include "mojo/public/cpp/bindings/remote_set.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
#include "third_party/blink/public/mojom/on_device_translation/translator.mojom.h"
#include "url/origin.h"
namespace on_device_translation {
namespace {
// Returns true unless `character` is ASCII whitespace or an ASCII control
// character. NOTE(review): only ASCII classes are consulted here, so
// non-ASCII whitespace (e.g. U+00A0 encoded as UTF-8 bytes) is presumably
// treated as translatable — confirm this matches the spec's intent.
bool IsTranslatableCharacter(char character) {
  const bool is_blank_or_control =
      base::IsAsciiWhitespace(character) || base::IsAsciiControl(character);
  return !is_blank_or_control;
}
// Returns true if at least one byte of `input` is translatable per
// IsTranslatableCharacter(); an empty string has no translatable content.
bool ContainsTranslatableContent(const std::string& input) {
  const auto first_translatable =
      std::find_if(input.cbegin(), input.cend(), &IsTranslatableCharacter);
  return first_translatable != input.cend();
}
}  // namespace
Translator::Translator(
base::WeakPtr<content::BrowserContext> browser_context,
const std::string& source_lang,
const std::string& target_lang,
mojo::PendingRemote<on_device_translation::mojom::Translator> remote)
: browser_context_(browser_context),
source_lang_(source_lang),
target_lang_(target_lang),
translator_remote_(std::move(remote)) {}
Translator::~Translator() = default;
bool Translator::VerifyPrerequisites(
const std::string& input,
mojo::Remote<blink::mojom::ModelStreamingResponder>& responder) {
if (!Profile::FromBrowserContext(browser_context_.get())
->GetPrefs()
->GetBoolean(prefs::kTranslatorAPIAllowed)) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure,
/*quota_error_info=*/nullptr);
return false;
}
RecordTranslationAPICallForLanguagePair("Translate", source_lang_,
target_lang_);
RecordTranslationCharacterCount(source_lang_, target_lang_, input.size());
// https://github.com/webmachinelearning/translation-api/pull/38: "If |input|
// is the empty string, or otherwise consists of no translatable content
// (e.g., only contains whitespace, or control characters), then the resulting
// translation should be |input|. In such cases, |sourceLanguage| and
// |targetLanguage| should be ignored."
if (!ContainsTranslatableContent(input)) {
responder->OnStreaming(input);
responder->OnCompletion(/*context_info=*/nullptr);
return false;
}
return true;
}
void Translator::Translate(
const std::string& input,
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
pending_responder) {
CHECK(browser_context_);
mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
std::move(pending_responder));
if (!VerifyPrerequisites(input, responder)) {
return;
}
if (translator_remote_.is_connected()) {
translator_remote_->Translate(
input,
mojo::WrapCallbackWithDefaultInvokeIfNotRun(
base::BindOnce(
[](mojo::Remote<blink::mojom::ModelStreamingResponder>
responder,
const std::optional<std::string>& output) {
if (!output) {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::
kErrorGenericFailure,
/*quota_error_info=*/nullptr);
return;
}
responder->OnStreaming(output.value());
responder->OnCompletion(/*context_info=*/nullptr);
},
std::move(responder)),
std::nullopt));
} else {
responder->OnError(
blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure,
/*quota_error_info=*/nullptr);
}
}
// Handles the reply to the SplitSentences() request issued by
// TranslateStreaming().
//
// An empty `sentences` is reported to `responder` as kErrorGenericFailure.
// This includes the empty default supplied when the SplitSentences callback
// is dropped. Otherwise `responder` is added to `responder_set_`, the number
// of outstanding sentence translations is stored in `pending_translations_`,
// and one Translate() request is issued per sentence. Each result is handled
// by TranslateStreamingCallback().
void Translator::SplitSentencesCallback(
    mojo::Remote<blink::mojom::ModelStreamingResponder> responder,
    const std::vector<std::string>& sentences) {
  if (sentences.empty()) {
    responder->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure,
        /*quota_error_info=*/nullptr);
    return;
  }
  mojo::RemoteSetElementId responder_id =
      responder_set_.Add(std::move(responder));
  pending_translations_[responder_id] = sentences.size();
  const int total_translations = static_cast<int>(sentences.size());
  for (const auto& sentence : sentences) {
    // Wrap the reply so it still runs (with std::nullopt) if the translation
    // service drops it, e.g. because `translator_remote_` is or becomes
    // disconnected. Without this the responder would never receive
    // OnError()/OnCompletion(), and its entries in `responder_set_` and
    // `pending_translations_` would linger for the Translator's lifetime.
    // TranslateStreamingCallback() already turns the first nullopt into an
    // error and ignores the remaining callbacks once the entry is cleared.
    // This matches the wrapping used by Translate() and TranslateStreaming().
    translator_remote_->Translate(
        sentence, mojo::WrapCallbackWithDefaultInvokeIfNotRun(
                      base::BindOnce(&Translator::TranslateStreamingCallback,
                                     weak_ptr_factory_.GetWeakPtr(),
                                     responder_id, total_translations),
                      std::nullopt));
  }
}
// Receives the translation of one sentence belonging to a streaming request.
//
// `responder_id` keys both the request's responder in `responder_set_` and
// its outstanding-sentence counter in `pending_translations_`.
// `total_translations` is the number of sentences the request was split into.
// A null `output` fails the whole request with kErrorGenericFailure.
// Otherwise the chunk is streamed and the counter is decremented; the request
// completes when the counter reaches zero. Because each Translate() call
// strips leading and trailing whitespace, every chunk after the first is
// prefixed with a single space. NOTE(review): "first" is detected by the
// counter still equalling `total_translations`, which assumes replies arrive
// in request order.
void Translator::TranslateStreamingCallback(
    mojo::RemoteSetElementId responder_id,
    int total_translations,
    const std::optional<std::string>& output) {
  blink::mojom::ModelStreamingResponder* const responder =
      responder_set_.Get(responder_id);
  auto entry = pending_translations_.find(responder_id);

  if (entry == pending_translations_.end()) {
    // The counter is only erased after the request finished or its responder
    // disconnected, so no live responder may remain at this point.
    CHECK(!responder);
    return;
  }
  if (!responder) {
    // The responder disconnected; drop the request's bookkeeping.
    pending_translations_.erase(entry);
    return;
  }
  if (!output.has_value()) {
    responder->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure,
        /*quota_error_info=*/nullptr);
    responder_set_.Remove(responder_id);
    pending_translations_.erase(entry);
    return;
  }

  const int remaining = entry->second;
  if (remaining != total_translations) {
    // Not the first chunk: restore the separator stripped by Translate().
    responder->OnStreaming(base::StrCat({" ", *output}));
  } else {
    responder->OnStreaming(*output);
  }

  if (--entry->second != 0) {
    return;
  }
  responder->OnCompletion(/*context_info=*/nullptr);
  responder_set_.Remove(responder_id);
  pending_translations_.erase(entry);
}
// Translates `input` and streams the result to `pending_responder`.
//
// When `kTranslateStreamingBySentence` is disabled this simply delegates to
// Translate(), which produces a single chunk. Otherwise, after the shared
// prerequisite checks, `input` is first split into sentences by the service
// and SplitSentencesCallback() translates and streams them one by one. If
// the service drops the SplitSentences reply, the callback receives an empty
// vector and reports an error.
void Translator::TranslateStreaming(
    const std::string& input,
    mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
        pending_responder) {
  CHECK(browser_context_);
  const bool stream_by_sentence =
      base::FeatureList::IsEnabled(kTranslateStreamingBySentence);
  if (!stream_by_sentence) {
    Translate(input, std::move(pending_responder));
    return;
  }

  mojo::Remote<blink::mojom::ModelStreamingResponder> responder(
      std::move(pending_responder));
  if (!VerifyPrerequisites(input, responder)) {
    return;
  }
  if (!translator_remote_.is_connected()) {
    responder->OnError(
        blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure,
        /*quota_error_info=*/nullptr);
    return;
  }

  auto on_split = base::BindOnce(&Translator::SplitSentencesCallback,
                                 weak_ptr_factory_.GetWeakPtr(),
                                 std::move(responder));
  translator_remote_->SplitSentences(
      input, mojo::WrapCallbackWithDefaultInvokeIfNotRun(
                 std::move(on_split), std::vector<std::string>()));
}
} // namespace on_device_translation