blob: b7e62be1b70ad741de1dbb30a2b59febb8efd327 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/omnibox/browser/on_device_tail_model_service.h"

#include <memory>
#include <optional>
#include <vector>

#include "base/containers/flat_set.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/path_service.h"
#include "base/strings/string_util.h"
#include "base/test/task_environment.h"
#include "components/omnibox/browser/on_device_tail_model_executor.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using ::testing::ElementsAreArray;
namespace {

// Test data files, resolved relative to components/test/data/omnibox.
// Entities in an unnamed namespace already have internal linkage, so `static`
// is redundant; `constexpr` matches the integral constants below.
constexpr char kTailModelFilename[] = "test_tail_model.tflite";
constexpr char kVocabFilename[] = "vocab_test.txt";

// LSTM parameters written into the model metadata handed to the service.
// NOTE(review): these presumably must match the shape of the checked-in
// test_tail_model.tflite — confirm if the test model is regenerated.
constexpr int kNumLayer = 1;
constexpr int kStateSize = 512;
constexpr int kEmbeddingDim = 64;

}  // namespace
// Test fixture that constructs an OnDeviceTailModelService backed by a
// TestOptimizationGuideModelProvider, and builds a ModelInfo that points at the
// test tail model and vocabulary files under components/test/data/omnibox.
// Individual tests deliver |model_info_| to the service themselves.
class OnDeviceTailModelServiceTest : public ::testing::Test {
 protected:
  void SetUp() override {
    test_model_provider_ = std::make_unique<
        optimization_guide::TestOptimizationGuideModelProvider>();
    service_ =
        std::make_unique<OnDeviceTailModelService>(test_model_provider_.get());

    optimization_guide::proto::Any any_metadata;
    any_metadata.set_type_url(
        "type.googleapis.com/com.foo.OnDeviceTailSuggestModelMetadata");

    // Fail fast if the source root cannot be resolved; otherwise the model
    // paths below would be bogus and the failure would only show up later as
    // an opaque "executor not ready".
    base::FilePath test_data_dir;
    ASSERT_TRUE(
        base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir));
    test_data_dir = test_data_dir.AppendASCII("components/test/data/omnibox");

    base::flat_set<base::FilePath> additional_files;
    additional_files.insert(test_data_dir.AppendASCII(kVocabFilename));

    optimization_guide::proto::OnDeviceTailSuggestModelMetadata metadata;
    metadata.mutable_lstm_model_params()->set_num_layer(kNumLayer);
    metadata.mutable_lstm_model_params()->set_state_size(kStateSize);
    metadata.mutable_lstm_model_params()->set_embedding_dimension(
        kEmbeddingDim);
    // Serialization failure would silently produce empty metadata.
    ASSERT_TRUE(metadata.SerializeToString(any_metadata.mutable_value()));

    model_info_ =
        optimization_guide::TestModelInfoBuilder()
            .SetModelFilePath(test_data_dir.AppendASCII(kTailModelFilename))
            .SetAdditionalFiles(additional_files)
            .SetVersion(123)
            .SetModelMetadata(any_metadata)
            .Build();
    task_environment_.RunUntilIdle();
  }

  void TearDown() override {
    // Destroy the service explicitly, then run any tasks still pending on the
    // task environment before the fixture itself is destroyed.
    service_ = nullptr;
    task_environment_.RunUntilIdle();
  }

  // Returns whether the service's tail model executor reports ready.
  // NOTE(review): reads a private member of OnDeviceTailModelService; this
  // presumably relies on the fixture being a friend — verify in the header.
  bool IsExecutorReady() const {
    return service_->tail_model_executor_->IsReady();
  }

  // Declared first so it is destroyed last, after everything that may post
  // tasks to it.
  base::test::TaskEnvironment task_environment_;
  // Declared before |service_| so it outlives the service, which keeps a raw
  // pointer to it, even if TearDown() does not reset |service_|.
  std::unique_ptr<optimization_guide::TestOptimizationGuideModelProvider>
      test_model_provider_;
  std::unique_ptr<OnDeviceTailModelService> service_;
  std::unique_ptr<optimization_guide::ModelInfo> model_info_;
};
// Delivering the valid test model through OnModelUpdated() should leave the
// service's tail model executor reporting ready once pending tasks have run.
TEST_F(OnDeviceTailModelServiceTest, OnModelUpdated) {
  constexpr auto kTarget = optimization_guide::proto::OptimizationTarget::
      OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST;

  service_->OnModelUpdated(kTarget, *model_info_);
  task_environment_.RunUntilIdle();

  EXPECT_TRUE(IsExecutorReady());
}
// After the valid test model is delivered, requesting predictions for the
// prefix "faceb" should yield at least one result, and the first result's
// suggestion should start with "facebook".
TEST_F(OnDeviceTailModelServiceTest, GetPredictionsForInput) {
  std::vector<OnDeviceTailModelExecutor::Prediction> results;
  // NOTE(review): the meaning of the positional arguments (5, 20, 0.05) is
  // defined by OnDeviceTailModelExecutor::ModelInput — confirm there.
  OnDeviceTailModelExecutor::ModelInput input("faceb", "", 5, 20, 0.05);
  // Copies the delivered predictions into |results| for inspection below.
  OnDeviceTailModelService::ResultCallback callback = base::BindOnce(
      [](std::vector<OnDeviceTailModelExecutor::Prediction>* results,
         std::vector<OnDeviceTailModelExecutor::Prediction> predictions) {
        *results = std::move(predictions);
      },
      &results);

  service_->OnModelUpdated(
      optimization_guide::proto::OptimizationTarget::
          OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST,
      *model_info_);
  service_->GetPredictionsForInput(input, std::move(callback));
  task_environment_.RunUntilIdle();

  // Must be fatal: with a non-fatal EXPECT the test would continue and index
  // results[0] on an empty vector.
  ASSERT_FALSE(results.empty());
  EXPECT_TRUE(base::StartsWith(results[0].suggestion, "facebook",
                               base::CompareCase::SENSITIVE));
}
// A null (std::nullopt) model update following a valid one should leave the
// executor reporting not ready.
TEST_F(OnDeviceTailModelServiceTest, NullModelUpdate) {
  service_->OnModelUpdated(
      optimization_guide::proto::OptimizationTarget::
          OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST,
      *model_info_);
  task_environment_.RunUntilIdle();
  // Precondition: if the valid model never loaded, the final EXPECT_FALSE
  // would pass trivially without exercising the null-update path.
  ASSERT_TRUE(IsExecutorReady());

  // A null model update should disable the executor.
  service_->OnModelUpdated(
      optimization_guide::proto::OptimizationTarget::
          OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST,
      std::nullopt);
  task_environment_.RunUntilIdle();
  EXPECT_FALSE(IsExecutorReady());
}