// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/history_embeddings/ml_answerer.h"
#include "base/memory/raw_ptr.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/model_execution/test/mock_on_device_capability.h"
#include "components/optimization_guide/core/model_quality/test_model_quality_logs_uploader_service.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/proto/features/history_answer.pb.h"
#include "components/prefs/testing_pref_service.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace history_embeddings {
namespace {
using ::base::test::TestFuture;
using ::optimization_guide::AnyWrapProto;
using ::optimization_guide::MockSession;
using ::optimization_guide::OptimizationGuideModelExecutionError;
using ::optimization_guide::OptimizationGuideModelStreamingExecutionResult;
using ::optimization_guide::proto::HistoryAnswerResponse;
using ::testing::_;
using ::testing::NiceMock;
}  // namespace
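
// Tracks a counter that test stubs use to hand out a different mock session
// on each StartSession() call.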
class MockModelExecutor : public optimization_guide::MockOnDeviceCapability {
public:
size_t GetCounter() const { return counter_; }
void IncrementCounter() { counter_++; }
void Reset() { counter_ = 0; }
private:
size_t counter_ = 0;
};
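
// Test fixture that wires an MlAnswerer to a mock model executor and two mock
// sessions that share one TokenLimits instance.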
class HistoryEmbeddingsMlAnswererTest : public testing::Test {
public:
void SetUp() override {
logs_uploader_ = std::make_unique<
optimization_guide::TestModelQualityLogsUploaderService>(&local_state_);
ml_answerer_ =
std::make_unique<MlAnswerer>(&model_executor_, logs_uploader_.get());
token_limits_ = {
.min_context_tokens = 1024,
};
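// Both mock sessions report the shared token_limits_.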
ON_CALL(session_1_, GetTokenLimits())
.WillByDefault([&]() -> optimization_guide::TokenLimits& {
return GetTokenLimits();
});
ON_CALL(session_2_, GetTokenLimits())
.WillByDefault([&]() -> optimization_guide::TokenLimits& {
return GetTokenLimits();
});
}
void TearDown() override {
// Reset the logs uploader to avoid keeping a dangling pointer to the local
// state during destruction.
logs_uploader_ = nullptr;
}
optimization_guide::TokenLimits& GetTokenLimits() { return token_limits_; }
protected:
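// Returns a StreamingResponse whose HistoryAnswerResponse carries
// |answer_text|; the response is marked complete by default.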
optimization_guide::StreamingResponse MakeResponse(
const std::string& answer_text,
bool is_complete = true) {
HistoryAnswerResponse answer_response;
answer_response.mutable_answer()->set_text(answer_text);
return optimization_guide::StreamingResponse{
.response = AnyWrapProto(answer_response),
.is_complete = is_complete,
};
}
base::test::TaskEnvironment task_environment_;
testing::NiceMock<MockModelExecutor> model_executor_;
std::unique_ptr<optimization_guide::TestModelQualityLogsUploaderService>
logs_uploader_;
TestingPrefServiceSimple local_state_;
std::unique_ptr<MlAnswerer> ml_answerer_;
testing::NiceMock<optimization_guide::MockSession> session_1_;
testing::NiceMock<optimization_guide::MockSession> session_2_;
optimization_guide::TokenLimits token_limits_;
};
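
// Verifies that ComputeAnswer() reports kModelUnavailable when the executor
// cannot start a session.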
TEST_F(HistoryEmbeddingsMlAnswererTest, ComputeAnswerNoSession) {
ON_CALL(model_executor_, StartSession(_, _)).WillByDefault([&] {
return nullptr;
});
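// Note: StartSession() above returns nullptr, so this session stub is never
// exercised in this test.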
ON_CALL(session_1_, GetSizeInTokens(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(100); }));
Answerer::Context context("1");
context.url_passages_map.insert({"url_1", {"passage_11", "passage_12"}});
TestFuture<AnswererResult> result_future;
ml_answerer_->ComputeAnswer("query", context, result_future.GetCallback());
AnswererResult result = result_future.Take();
EXPECT_EQ(ComputeAnswerStatus::kModelUnavailable, result.status);
}

#if !BUILDFLAG(IS_FUCHSIA)
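// Verifies that a streaming execution error surfaces as kExecutionFailure.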
TEST_F(HistoryEmbeddingsMlAnswererTest, ComputeAnswerExecutionFailure) {
ON_CALL(model_executor_, StartSession(_, _)).WillByDefault([&] {
return std::make_unique<NiceMock<MockSession>>(&session_1_);
});
ON_CALL(session_1_, GetSizeInTokens(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(100); }));
ON_CALL(session_1_, Score(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelScoreCallback
callback) { std::move(callback).Run(0.6); }));
ON_CALL(session_1_, ExecuteModel(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE,
base::BindOnce(
std::move(callback),
OptimizationGuideModelStreamingExecutionResult(
base::unexpected(
OptimizationGuideModelExecutionError::
FromModelExecutionError(
OptimizationGuideModelExecutionError::
ModelExecutionError::kGenericFailure)),
/*provided_by_on_device=*/true, nullptr)));
}));
Answerer::Context context("1");
context.url_passages_map.insert({"url_1", {"passage_11", "passage_12"}});
TestFuture<AnswererResult> result_future;
ml_answerer_->ComputeAnswer("query", context, result_future.GetCallback());
AnswererResult result = result_future.Take();
EXPECT_EQ(ComputeAnswerStatus::kExecutionFailure, result.status);
}
#endif  // !BUILDFLAG(IS_FUCHSIA)
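
// Verifies that the answer and URL come from the single available session.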
TEST_F(HistoryEmbeddingsMlAnswererTest, ComputeAnswerSingleUrl) {
ON_CALL(model_executor_, StartSession(_, _)).WillByDefault([&] {
return std::make_unique<NiceMock<MockSession>>(&session_1_);
});
ON_CALL(session_1_, GetSizeInTokens(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(100); }));
ON_CALL(session_1_, Score(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelScoreCallback
callback) { std::move(callback).Run(0.6); }));
ON_CALL(session_1_, ExecuteModel(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE,
base::BindOnce(std::move(callback),
OptimizationGuideModelStreamingExecutionResult(
base::ok(MakeResponse("Answer_1")),
/*provided_by_on_device=*/true, nullptr)));
}));
Answerer::Context context("1");
context.url_passages_map.insert({"url_1", {"passage_11", "passage_12"}});
TestFuture<AnswererResult> result_future;
ml_answerer_->ComputeAnswer("query", context, result_future.GetCallback());
AnswererResult answer_result = result_future.Take();
EXPECT_EQ("Answer_1", answer_result.answer.text());
EXPECT_EQ("url_1", answer_result.url);
}
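
// Verifies that with multiple URLs, the answer comes from the session that
// reports the highest score.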
TEST_F(HistoryEmbeddingsMlAnswererTest, ComputeAnswerMultipleUrls) {
ON_CALL(model_executor_, StartSession(_, _))
.WillByDefault([&]() -> std::unique_ptr<MockSession> {
if (model_executor_.GetCounter() == 0) {
model_executor_.IncrementCounter();
return std::make_unique<NiceMock<MockSession>>(&session_1_);
} else if (model_executor_.GetCounter() == 1) {
model_executor_.IncrementCounter();
return std::make_unique<NiceMock<MockSession>>(&session_2_);
}
return nullptr;
});
ON_CALL(session_1_, GetSizeInTokens(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(100); }));
ON_CALL(session_2_, GetSizeInTokens(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(100); }));
ON_CALL(session_1_, Score(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelScoreCallback
callback) { std::move(callback).Run(0.6); }));
// Speculative decoding should continue with session_2_, which returns the
// higher score.
ON_CALL(session_2_, Score(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelScoreCallback
callback) { std::move(callback).Run(0.9); }));
ON_CALL(session_2_, ExecuteModel(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE,
base::BindOnce(std::move(callback),
OptimizationGuideModelStreamingExecutionResult(
base::ok(MakeResponse("Answer_2")),
/*provided_by_on_device=*/true, nullptr)));
}));
Answerer::Context context("1");
context.url_passages_map.insert({"url_1", {"passage_11", "passage_12"}});
context.url_passages_map.insert({"url_2", {"passage_21", "passage_22"}});
TestFuture<AnswererResult> result_future;
ml_answerer_->ComputeAnswer("query", context, result_future.GetCallback());
AnswererResult answer_result = result_future.Take();
// session_2_.Score() returns the higher score, so the answer comes from
// session_2_.ExecuteModel(). StartSession() is called once per entry in
// url_passages_map, returning session_1_ first and session_2_ second, but
// url_passages_map is an unordered_map with unspecified iteration order, so
// the winning URL may be either url_1 or url_2.
EXPECT_EQ("Answer_2", answer_result.answer.text());
EXPECT_TRUE(answer_result.url == "url_1" || answer_result.url == "url_2");
}
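
// Verifies that a score below the answerability threshold yields
// kUnanswerable.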
TEST_F(HistoryEmbeddingsMlAnswererTest, ComputeAnswerUnanswerable) {
ON_CALL(model_executor_, StartSession(_, _)).WillByDefault([&] {
return std::make_unique<NiceMock<MockSession>>(&session_1_);
});
ON_CALL(session_1_, GetSizeInTokens(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(100); }));
// Below the default 0.5 threshold.
ON_CALL(session_1_, Score(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelScoreCallback
callback) { std::move(callback).Run(0.3); }));
Answerer::Context context("1");
context.url_passages_map.insert({"url_1", {"passage_11", "passage_12"}});
TestFuture<AnswererResult> future;
ml_answerer_->ComputeAnswer("query", context, future.GetCallback());
const auto answer_result = future.Take();
EXPECT_EQ(ComputeAnswerStatus::kUnanswerable, answer_result.status);
}
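
// Verifies that a null score is reported as kExecutionFailure.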
TEST_F(HistoryEmbeddingsMlAnswererTest, ComputeAnswerNullScores) {
ON_CALL(model_executor_, StartSession(_, _)).WillByDefault([&] {
return std::make_unique<NiceMock<MockSession>>(&session_1_);
});
ON_CALL(session_1_, GetSizeInTokens(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(100); }));
// Return a null score to simulate a scoring failure.
ON_CALL(session_1_, Score(_, _))
.WillByDefault(testing::WithArg<1>(
[&](optimization_guide::OptimizationGuideModelScoreCallback
callback) { std::move(callback).Run(std::nullopt); }));
Answerer::Context context("1");
context.url_passages_map.insert({"url_1", {"passage_11", "passage_12"}});
TestFuture<AnswererResult> future;
ml_answerer_->ComputeAnswer("query", context, future.GetCallback());
const auto answer_result = future.Take();
EXPECT_EQ(ComputeAnswerStatus::kExecutionFailure, answer_result.status);
}

}  // namespace history_embeddings