| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/services/on_device_translation/translate_kit_client.h" |
| |
| #include <map> |
| #include <memory> |
| #include <optional> |
| #include <string> |
| #include <string_view> |
| |
| #include "base/check.h" |
| #include "base/compiler_specific.h" |
| #include "base/files/memory_mapped_file.h" |
| #include "base/logging.h" |
| #include "base/metrics/histogram_functions.h" |
| #include "base/no_destructor.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/string_split.h" |
| #include "base/strings/string_util.h" |
| #include "base/types/expected.h" |
| #include "build/build_config.h" |
| #include "components/services/on_device_translation/proto/translate_kit_api.pb.h" |
| #include "components/services/on_device_translation/public/cpp/features.h" |
| #include "components/services/on_device_translation/translate_kit_structs.h" |
| |
| namespace on_device_translation { |
| namespace { |
| |
| using mojom::CreateTranslatorResult; |
| |
| // Logs UMA after an attempt to load the TranslateKit binary. |
| void LogLoadTranslateKitResult(LoadTranslateKitResult result, |
| const base::NativeLibraryLoadError* error) { |
| base::UmaHistogramEnumeration("AI.Translation.LoadTranslateKitResult", |
| result); |
| #if BUILDFLAG(IS_WIN) |
| if (result == LoadTranslateKitResult::kInvalidBinary) { |
| base::UmaHistogramSparse("AI.Translation.LoadTranslateKitErrorCode", |
| error->code); |
| } |
| #endif // BUILDFLAG(IS_WIN) |
| } |
| |
| // This method is used to receive the result from TranslatorTranslate() method. |
| void TranslateCallback(TranslateKitOutputText result, |
| std::uintptr_t user_data) { |
| std::string* output = reinterpret_cast<std::string*>(user_data); |
| CHECK(output); |
| CHECK(result.buffer); |
| *output = std::string(result.buffer, result.buffer_size); |
| } |
| |
| void DeleteReadOnlyMemoryRegion(std::uintptr_t memory_map_ptr, |
| std::uintptr_t user_data) { |
| CHECK(memory_map_ptr); |
| delete reinterpret_cast<base::MemoryMappedFile*>(memory_map_ptr); |
| } |
| |
| const void* ReadOnlyMemoryRegionData(std::uintptr_t memory_map_ptr, |
| std::uintptr_t user_data) { |
| CHECK(memory_map_ptr); |
| return reinterpret_cast<base::MemoryMappedFile*>(memory_map_ptr)->data(); |
| } |
| |
| uint64_t ReadOnlyMemoryRegionLength(std::uintptr_t memory_map_ptr, |
| std::uintptr_t user_data) { |
| CHECK(memory_map_ptr); |
| return reinterpret_cast<base::MemoryMappedFile*>(memory_map_ptr)->length(); |
| } |
| |
| void ParseFilePath(const char* file_name, |
| size_t file_name_size, |
| uint32_t& package_index, |
| base::FilePath& relative_path) { |
| std::string path(file_name, file_name_size); |
| // The TranslateKit only use ASCII paths. |
| CHECK(base::IsStringASCII(path)); |
| #if BUILDFLAG(IS_WIN) |
| base::ReplaceChars(path, "/", "\\", &path); |
| #endif // BUILDFLAG(IS_WIN) |
| base::FilePath virtual_path = base::FilePath::FromASCII(path); |
| // The TranslateKit doesn't use '..'. |
| CHECK(!virtual_path.ReferencesParent()); |
| // The TranslateKit must use an absolute path. |
| CHECK(virtual_path.IsAbsolute()); |
| const std::vector<base::FilePath::StringType> components = |
| virtual_path.GetComponents(); |
| #if BUILDFLAG(IS_WIN) |
| // Windows: "X:\0\bar" -> [ "X:", "\\", "0", "bar" ] |
| // ^^^ : component_idx = 2 |
| size_t component_idx = 2; |
| #else |
| // Posix: "/0/bar" -> [ "/", "0", "bar" ] |
| // ^^^ : component_idx = 1 |
| size_t component_idx = 1; |
| #endif // BUILDFLAG(IS_WIN) |
| CHECK_GT(components.size(), component_idx + 1); |
| CHECK(base::StringToUint(components[component_idx], &package_index)); |
| ++component_idx; |
| base::FilePath result; |
| for (; component_idx < components.size(); ++component_idx) { |
| result = result.Append(components[component_idx]); |
| } |
| relative_path = result; |
| } |
| |
| } // namespace |
| |
| // static |
| TranslateKitClient* TranslateKitClient::Get() { |
| static base::NoDestructor<std::unique_ptr<TranslateKitClient>> client( |
| std::make_unique<TranslateKitClient>( |
| GetTranslateKitBinaryPathFromCommandLine(), |
| base::PassKey<TranslateKitClient>())); |
| return client->get(); |
| } |
| |
| // static |
| std::unique_ptr<TranslateKitClient> TranslateKitClient::CreateForTest( |
| const base::FilePath& library_path) { |
| return std::make_unique<TranslateKitClient>( |
| library_path, base::PassKey<TranslateKitClient>()); |
| } |
| |
| TranslateKitClient::TranslateKitClient(const base::FilePath& library_path, |
| base::PassKey<TranslateKitClient>) |
| : lib_(library_path), |
| get_translate_kit_version_func_( |
| reinterpret_cast<GetTranslateKitVersionFn>( |
| lib_.GetFunctionPointer("GetTranslateKitVersion"))), |
| initialize_storage_backend_fnc_( |
| reinterpret_cast<InitializeStorageBackendFn>( |
| lib_.GetFunctionPointer("InitializeStorageBackend"))), |
| create_translate_kit_fnc_(reinterpret_cast<CreateTranslateKitFn>( |
| lib_.GetFunctionPointer("CreateTranslateKit"))), |
| delete_tanslate_kit_fnc_(reinterpret_cast<DeleteTranslateKitFn>( |
| lib_.GetFunctionPointer("DeleteTranslateKit"))), |
| set_language_packages_func_( |
| reinterpret_cast<TranslateKitSetLanguagePackagesFn>( |
| lib_.GetFunctionPointer("TranslateKitSetLanguagePackages"))), |
| translate_kit_create_translator_func_( |
| reinterpret_cast<TranslateKitCreateTranslatorFn>( |
| lib_.GetFunctionPointer("TranslateKitCreateTranslator"))), |
| delete_translator_fnc_(reinterpret_cast<DeleteTranslatorFn>( |
| lib_.GetFunctionPointer("DeleteTranslator"))), |
| translator_translate_func_(reinterpret_cast<TranslatorTranslateFn>( |
| lib_.GetFunctionPointer("TranslatorTranslate"))), |
| translate_kit_sentence_split_func_( |
| reinterpret_cast<TranslateKitSplitSentencesFn>( |
| lib_.GetFunctionPointer("TranslateKitSplitSentences"))) { |
| LogLoadTranslateKitResult(CheckLoadTranslateKitResult(), lib_.GetError()); |
| } |
| |
| LoadTranslateKitResult TranslateKitClient::CheckLoadTranslateKitResult() { |
| if (!lib_.is_valid()) { |
| maybe_kit_ptr_ = |
| base::unexpected(CreateTranslatorResult::kErrorInvalidBinary); |
| return LoadTranslateKitResult::kInvalidBinary; |
| } |
| |
| if (!get_translate_kit_version_func_ || !initialize_storage_backend_fnc_ || |
| !create_translate_kit_fnc_ || !delete_tanslate_kit_fnc_ || |
| !set_language_packages_func_ || !translate_kit_create_translator_func_ || |
| !delete_translator_fnc_ || !translator_translate_func_) { |
| maybe_kit_ptr_ = |
| base::unexpected(CreateTranslatorResult::kErrorInvalidFunctionPointer); |
| return LoadTranslateKitResult::kInvalidFunctionPointer; |
| } |
| |
| if (!IsTranslateKitVersionValid()) { |
| maybe_kit_ptr_ = |
| base::unexpected(CreateTranslatorResult::kErrorInvalidVersion); |
| return LoadTranslateKitResult::kInvalidVersion; |
| } |
| |
| return LoadTranslateKitResult::kSuccess; |
| } |
| |
| DISABLE_CFI_DLSYM |
| bool TranslateKitClient::IsTranslateKitVersionValid() { |
| std::string version_buffer(kTranslationAPILibraryVersionStringSize, '\0'); |
| TranslateKitVersion version{version_buffer.data(), version_buffer.size()}; |
| if (!get_translate_kit_version_func_(&version)) { |
| return false; |
| } |
| |
| std::string_view version_string(version.buffer, version.buffer_size); |
| return IsValidTranslateKitVersion(version_string); |
| } |
| |
| DISABLE_CFI_DLSYM |
| bool TranslateKitClient::MaybeInitialize() { |
| if (maybe_kit_ptr_.has_value() && *maybe_kit_ptr_) { |
| // Already successfully initialized. |
| return true; |
| } |
| if (!maybe_kit_ptr_.has_value()) { |
| // An error occurred while loading the TranslateKit binary or the previous |
| // initialization failed. |
| return false; |
| } |
| initialize_storage_backend_fnc_( |
| &TranslateKitClient::FileExists, |
| &TranslateKitClient::OpenForReadOnlyMemoryMap, |
| &DeleteReadOnlyMemoryRegion, &ReadOnlyMemoryRegionData, |
| &ReadOnlyMemoryRegionLength, reinterpret_cast<std::uintptr_t>(this)); |
| |
| maybe_kit_ptr_ = create_translate_kit_fnc_(); |
| if (!*maybe_kit_ptr_) { |
| maybe_kit_ptr_ = |
| base::unexpected(CreateTranslatorResult::kErrorFailedToInitialize); |
| return false; |
| } |
| return true; |
| } |
| |
| DISABLE_CFI_DLSYM |
| void TranslateKitClient::SetConfig( |
| mojom::OnDeviceTranslationServiceConfigPtr config) { |
| if (!MaybeInitialize()) { |
| return; |
| } |
| |
| // When `file_operation_proxy_` is set, need to reset `file_operation_proxy_` |
| // before binding the new one. This happens when SetConfig() is called again |
| // for the new config. |
| file_operation_proxy_.reset(); |
| file_operation_proxy_.Bind(std::move(config->file_operation_proxy)); |
| chrome::on_device_translation::TranslateKitLanguagePackageConfig config_proto; |
| size_t index = 0; |
| for (const auto& package : config->packages) { |
| // Generate a virtual absolute file path for the package. |
| // On Windows, set the package path to a fake drive letter 'X:' to avoid |
| // the file path validation in the TranslateKit. |
| const std::string package_path = |
| #if BUILDFLAG(IS_WIN) |
| base::StrCat({"X:\\", base::NumberToString(index++)}); |
| #else |
| base::StrCat({"/", base::NumberToString(index++)}); |
| #endif // BUILDFLAG(IS_WIN) |
| auto* new_package = config_proto.add_packages(); |
| new_package->set_language1(package->language1); |
| new_package->set_language2(package->language2); |
| new_package->set_package_path(package_path); |
| } |
| |
| const std::string packages_str = config_proto.SerializeAsString(); |
| CHECK(set_language_packages_func_( |
| *maybe_kit_ptr_, |
| TranslateKitSetLanguagePackagesArgs{packages_str.c_str(), |
| packages_str.size()})) |
| << "Failed to set config"; |
| } |
| |
| DISABLE_CFI_DLSYM |
| TranslateKitClient::~TranslateKitClient() { |
| if (!maybe_kit_ptr_.has_value() || !*maybe_kit_ptr_) { |
| return; |
| } |
| delete_tanslate_kit_fnc_(*maybe_kit_ptr_); |
| maybe_kit_ptr_ = 0; |
| translators_.clear(); |
| } |
| |
| bool TranslateKitClient::CanTranslate(const std::string& source_lang, |
| const std::string& target_lang) { |
| if (!maybe_kit_ptr_.has_value()) { |
| return false; |
| } |
| CHECK(*maybe_kit_ptr_) << "SetConfig must have been called"; |
| CHECK(file_operation_proxy_); |
| |
| return TranslateKitClient::GetTranslator(source_lang, target_lang) |
| .has_value(); |
| } |
| |
| base::expected<TranslateKitClient::Translator*, CreateTranslatorResult> |
| TranslateKitClient::GetTranslator(const std::string& source_lang, |
| const std::string& target_lang) { |
| if (!maybe_kit_ptr_.has_value()) { |
| CHECK_NE(maybe_kit_ptr_.error(), CreateTranslatorResult::kSuccess); |
| return base::unexpected(maybe_kit_ptr_.error()); |
| } |
| CHECK(*maybe_kit_ptr_) << "SetConfig must have been called"; |
| CHECK(file_operation_proxy_); |
| |
| TranslatorKey key(source_lang, target_lang); |
| if (auto it = translators_.find(key); it != translators_.end()) { |
| return it->second.get(); |
| } |
| auto translator = TranslatorImpl::MaybeCreate(this, source_lang, target_lang); |
| if (!translator) { |
| return base::unexpected( |
| CreateTranslatorResult::kErrorFailedToCreateTranslator); |
| } |
| auto raw_translator_ptr = translator.get(); |
| translators_.emplace(std::move(key), std::move(translator)); |
| return raw_translator_ptr; |
| } |
| |
| DISABLE_CFI_DLSYM |
| std::unique_ptr<TranslateKitClient::TranslatorImpl> |
| TranslateKitClient::TranslatorImpl::MaybeCreate( |
| TranslateKitClient* client, |
| const std::string& source_lang, |
| const std::string& target_lang) { |
| CHECK(client->translate_kit_create_translator_func_); |
| std::uintptr_t translator_ptr = client->translate_kit_create_translator_func_( |
| *client->maybe_kit_ptr_, |
| TranslateKitLanguage(source_lang.c_str(), source_lang.length()), |
| TranslateKitLanguage(target_lang.c_str(), target_lang.length())); |
| if (!translator_ptr) { |
| return nullptr; |
| } |
| return std::make_unique<TranslatorImpl>(base::PassKey<TranslatorImpl>(), |
| client, source_lang, translator_ptr); |
| } |
| |
| // Wraps a library translator handle together with the client that created |
| // it. |
| // |
| // `client` is stored in `client_` and provides the library function pointers |
| // used by the other methods. `source_lang` is copied into `source_lang_`. |
| // `translator_ptr` is the handle that is later passed to the library's |
| // translate and delete-translator functions. |
| TranslateKitClient::TranslatorImpl::TranslatorImpl( |
| base::PassKey<TranslatorImpl>, |
| TranslateKitClient* client, |
| const std::string& source_lang, |
| std::uintptr_t translator_ptr) |
| : client_(client), |
| source_lang_(source_lang), |
| translator_ptr_(translator_ptr) {} |
| |
| // Releases the library-side translator identified by `translator_ptr_` by |
| // calling the client's delete-translator function pointer. |
| DISABLE_CFI_DLSYM |
| TranslateKitClient::TranslatorImpl::~TranslatorImpl() { |
| // A missing function pointer is fatal. |
| CHECK(client_->delete_translator_fnc_); |
| client_->delete_translator_fnc_(translator_ptr_); |
| } |
| |
| DISABLE_CFI_DLSYM |
| std::optional<std::string> TranslateKitClient::TranslatorImpl::Translate( |
| const std::string& text) { |
| CHECK(client_->translator_translate_func_); |
| std::string output; |
| if (client_->translator_translate_func_( |
| translator_ptr_, TranslateKitInputText(text.c_str(), text.length()), |
| TranslateCallback, reinterpret_cast<uintptr_t>(&text))) { |
| return text; |
| } |
| return std::nullopt; |
| } |
| |
| DISABLE_CFI_DLSYM |
| std::vector<std::string> TranslateKitClient::TranslatorImpl::SplitSentences( |
| const std::string& text) { |
| std::vector<std::string> sentences; |
| if (!client_->translate_kit_sentence_split_func_) { |
| sentences.push_back(text); |
| return sentences; |
| } |
| |
| auto sentence_callback = [](TranslateKitOutputText result, |
| std::uintptr_t user_data) { |
| std::vector<std::string>* sentences = |
| reinterpret_cast<std::vector<std::string>*>(user_data); |
| CHECK(sentences); |
| CHECK(result.buffer); |
| std::string sentence(result.buffer, result.buffer_size); |
| sentences->push_back(std::move(sentence)); |
| }; |
| |
| if (!client_->translate_kit_sentence_split_func_( |
| TranslateKitInputText(text.c_str(), text.length()), |
| TranslateKitLanguage(source_lang_.c_str(), source_lang_.length()), |
| sentence_callback, reinterpret_cast<std::uintptr_t>(&sentences))) { |
| return std::vector<std::string>{text}; |
| } |
| |
| return sentences; |
| } |
| |
| // static |
| bool TranslateKitClient::FileExists(const char* file_name, |
| size_t file_name_size, |
| bool* is_directory, |
| std::uintptr_t user_data) { |
| CHECK(file_name); |
| CHECK(is_directory); |
| CHECK(user_data); |
| return reinterpret_cast<TranslateKitClient*>(user_data)->FileExistsImpl( |
| file_name, file_name_size, is_directory); |
| } |
| bool TranslateKitClient::FileExistsImpl(const char* file_name, |
| size_t file_name_size, |
| bool* is_directory) { |
| uint32_t package_index = 0; |
| base::FilePath relative_path; |
| ParseFilePath(file_name, file_name_size, package_index, relative_path); |
| bool exists = false; |
| CHECK(file_operation_proxy_); |
| file_operation_proxy_->FileExists(package_index, relative_path, &exists, |
| is_directory); |
| return exists; |
| } |
| |
| // static |
| std::uintptr_t TranslateKitClient::OpenForReadOnlyMemoryMap( |
| const char* file_name, |
| size_t file_name_size, |
| std::uintptr_t user_data) { |
| CHECK(file_name); |
| CHECK(user_data); |
| return reinterpret_cast<TranslateKitClient*>(user_data) |
| ->OpenForReadOnlyMemoryMapImpl(file_name, file_name_size); |
| } |
| |
| std::uintptr_t TranslateKitClient::OpenForReadOnlyMemoryMapImpl( |
| const char* file_name, |
| size_t file_name_size) { |
| uint32_t package_index = 0; |
| base::FilePath relative_path; |
| ParseFilePath(file_name, file_name_size, package_index, relative_path); |
| base::File file; |
| CHECK(file_operation_proxy_); |
| file_operation_proxy_->Open(package_index, relative_path, &file); |
| if (!file.IsValid()) { |
| return 0; |
| } |
| std::unique_ptr<base::MemoryMappedFile> mapped_file = |
| std::make_unique<base::MemoryMappedFile>(); |
| CHECK(mapped_file->Initialize(std::move(file))); |
| return reinterpret_cast<std::uintptr_t>(mapped_file.release()); |
| } |
| |
| } // namespace on_device_translation |