blob: fb654286071c4c9df40f850a0dd296e28268dfa9 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/omnibox/browser/on_device_tail_model_service.h"
#include "base/containers/flat_set.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/path_service.h"
#include "base/strings/string_util.h"
#include "base/test/task_environment.h"
#include "components/memory_pressure/fake_memory_pressure_monitor.h"
#include "components/omnibox/browser/on_device_tail_model_executor.h"
#include "components/optimization_guide/core/delivery/test_model_info_builder.h"
#include "components/optimization_guide/core/delivery/test_optimization_guide_model_provider.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using ::testing::ElementsAreArray;
namespace {

// File names of the test tail model and its vocabulary. The fixture appends
// them to the omnibox test data directory.
constexpr char kTailModelFilename[] = "test_tail_model.tflite";
constexpr char kVocabFilename[] = "vocab_test.txt";

// LSTM parameters that the fixture writes into the model metadata.
constexpr int kNumLayer = 1;
constexpr int kStateSize = 512;
constexpr int kEmbeddingDim = 64;
constexpr int kMaxNumSteps = 20;
constexpr float kProbabilityThreshold = 0.01f;

}  // namespace
class OnDeviceTailModelServiceTest : public ::testing::Test {
protected:
void SetUp() override {
test_model_provider_ = std::make_unique<
optimization_guide::TestOptimizationGuideModelProvider>();
service_ =
std::make_unique<OnDeviceTailModelService>(test_model_provider_.get());
optimization_guide::proto::Any any_metadata;
any_metadata.set_type_url(
"type.googleapis.com/com.foo.OnDeviceTailSuggestModelMetadata");
base::FilePath test_data_dir;
base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir);
test_data_dir = test_data_dir.AppendASCII("components/test/data/omnibox");
base::flat_set<base::FilePath> additional_files;
additional_files.insert(test_data_dir.AppendASCII(kVocabFilename));
optimization_guide::proto::OnDeviceTailSuggestModelMetadata metadata;
metadata.mutable_lstm_model_params()->set_num_layer(kNumLayer);
metadata.mutable_lstm_model_params()->set_state_size(kStateSize);
metadata.mutable_lstm_model_params()->set_embedding_dimension(
kEmbeddingDim);
metadata.mutable_lstm_model_params()->set_max_num_steps(kMaxNumSteps);
metadata.mutable_lstm_model_params()->set_probability_threshold(
kProbabilityThreshold);
metadata.SerializeToString(any_metadata.mutable_value());
model_info_ =
optimization_guide::TestModelInfoBuilder()
.SetModelFilePath(test_data_dir.AppendASCII(kTailModelFilename))
.SetAdditionalFiles(additional_files)
.SetVersion(123)
.SetModelMetadata(any_metadata)
.Build();
task_environment_.RunUntilIdle();
}
void TearDown() override {
service_ = nullptr;
task_environment_.RunUntilIdle();
}
bool IsExecutorReady() const {
return service_->tail_model_executor_->IsReady();
}
base::test::TaskEnvironment task_environment_;
std::unique_ptr<OnDeviceTailModelService> service_;
std::unique_ptr<optimization_guide::TestOptimizationGuideModelProvider>
test_model_provider_;
std::unique_ptr<optimization_guide::ModelInfo> model_info_;
};
// Delivering the test model to the service should leave the executor ready
// once all pending tasks have run.
TEST_F(OnDeviceTailModelServiceTest, OnModelUpdated) {
  const optimization_guide::proto::OptimizationTarget tail_suggest_target =
      optimization_guide::proto::OptimizationTarget::
          OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST;

  service_->OnModelUpdated(tail_suggest_target, *model_info_);
  task_environment_.RunUntilIdle();

  EXPECT_TRUE(IsExecutorReady());
}
// After the test model is delivered, asking for predictions for the prefix
// "faceb" should yield at least one result whose first suggestion starts with
// "facebook".
TEST_F(OnDeviceTailModelServiceTest, GetPredictionsForInput) {
  std::vector<OnDeviceTailModelExecutor::Prediction> results;
  OnDeviceTailModelExecutor::ModelInput input("faceb", "", 5);

  // The callback copies whatever the service delivers into `results`, which
  // outlives the RunUntilIdle() call below.
  OnDeviceTailModelService::ResultCallback callback = base::BindOnce(
      [](std::vector<OnDeviceTailModelExecutor::Prediction>* results,
         std::vector<OnDeviceTailModelExecutor::Prediction> predictions) {
        *results = std::move(predictions);
      },
      &results);

  service_->OnModelUpdated(
      optimization_guide::proto::OptimizationTarget::
          OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST,
      *model_info_);
  service_->GetPredictionsForInput(input, std::move(callback));
  task_environment_.RunUntilIdle();

  // ASSERT rather than EXPECT: the next check indexes `results[0]`, so the
  // test must stop here if no prediction was produced.
  ASSERT_FALSE(results.empty());
  EXPECT_TRUE(base::StartsWith(results[0].suggestion, "facebook",
                               base::CompareCase::SENSITIVE));
}
// A model update carrying no model, sent after a valid one, should leave the
// executor not ready.
TEST_F(OnDeviceTailModelServiceTest, NullModelUpdate) {
  const optimization_guide::proto::OptimizationTarget tail_suggest_target =
      optimization_guide::proto::OptimizationTarget::
          OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST;

  // Start from a state where the executor has a model.
  service_->OnModelUpdated(tail_suggest_target, *model_info_);
  task_environment_.RunUntilIdle();
  EXPECT_TRUE(IsExecutorReady());

  // Updating with std::nullopt in place of a model should disable the
  // executor.
  service_->OnModelUpdated(tail_suggest_target, std::nullopt);
  task_environment_.RunUntilIdle();
  EXPECT_FALSE(IsExecutorReady());
}
// Exercises the service's reaction to memory pressure notifications: under
// critical pressure the executor should be unloaded and requests should yield
// no predictions; once pressure drops to moderate, a request should bring the
// executor back and yield predictions.
TEST_F(OnDeviceTailModelServiceTest, MemoryPressureLevel) {
  using Predictions = std::vector<OnDeviceTailModelExecutor::Prediction>;

  // Builds a callback that stores the delivered predictions in `out`. `out`
  // must stay alive until the callback has run.
  const auto store_results_in =
      [](Predictions* out) -> OnDeviceTailModelService::ResultCallback {
    return base::BindOnce(
        [](Predictions* results, Predictions predictions) {
          *results = std::move(predictions);
        },
        out);
  };

  service_->OnModelUpdated(
      optimization_guide::proto::OptimizationTarget::
          OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST,
      *model_info_);
  task_environment_.RunUntilIdle();
  EXPECT_TRUE(IsExecutorReady());

  OnDeviceTailModelExecutor::ModelInput input("faceb", "", 5);
  memory_pressure::test::FakeMemoryPressureMonitor mem_pressure_monitor;

  // The executor should be unloaded from memory when memory pressure level is
  // critical.
  Predictions results_1;
  OnDeviceTailModelService::ResultCallback callback_1 =
      store_results_in(&results_1);
  mem_pressure_monitor.SetAndNotifyMemoryPressure(
      base::MemoryPressureListener::MemoryPressureLevel::
          MEMORY_PRESSURE_LEVEL_CRITICAL);
  service_->GetPredictionsForInput(input, std::move(callback_1));
  task_environment_.RunUntilIdle();
  EXPECT_FALSE(IsExecutorReady());
  EXPECT_TRUE(results_1.empty());

  // The executor should then be re-initialized once pressure level drops.
  Predictions results_2;
  OnDeviceTailModelService::ResultCallback callback_2 =
      store_results_in(&results_2);
  mem_pressure_monitor.SetAndNotifyMemoryPressure(
      base::MemoryPressureListener::MemoryPressureLevel::
          MEMORY_PRESSURE_LEVEL_MODERATE);
  service_->GetPredictionsForInput(input, std::move(callback_2));
  task_environment_.RunUntilIdle();
  EXPECT_TRUE(IsExecutorReady());
  EXPECT_FALSE(results_2.empty());
}