// blob: 00187088b33a63389ebf0d59f29ad0b241100469
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/history_embeddings/ml_intent_classifier.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/proto/features/history_query_intent.pb.h"
namespace history_embeddings {
namespace {
// Shorthand names for the optimization_guide mocks, streaming result types,
// gMock helpers, and history-query-intent protos used by the tests below.
using optimization_guide::MockOptimizationGuideModelExecutor;
using optimization_guide::MockSession;
using optimization_guide::MockSessionWrapper;
using optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback;
using optimization_guide::StreamingResponse;
using ExecResult =
optimization_guide::OptimizationGuideModelStreamingExecutionResult;
using testing::_;
using testing::NiceMock;
using optimization_guide::proto::HistoryQueryIntentRequest;
using optimization_guide::proto::HistoryQueryIntentResponse;
// Produces a canned model result from the text of the incoming request:
//   - text ending in "!"  -> MockSession::FailResult()
//   - text ending in "?"  -> success, response has is_answer_seeking == true
//   - anything else       -> success, response has is_answer_seeking == false
// `request_metadata` is cast, unchecked, to HistoryQueryIntentRequest.
auto FakeExecute(const google::protobuf::MessageLite& request_metadata) {
  const auto& intent_request =
      static_cast<const HistoryQueryIntentRequest&>(request_metadata);
  const auto& query = intent_request.text();

  // A trailing "!" is the signal to simulate an execution failure.
  if (query.ends_with("!")) {
    return MockSession::FailResult();
  }

  // Otherwise succeed; a trailing "?" marks the query as answer seeking.
  const bool answer_seeking = query.ends_with("?");
  HistoryQueryIntentResponse reply;
  reply.set_is_answer_seeking(answer_seeking);
  return MockSession::SuccessResult(optimization_guide::AnyWrapProto(reply));
}
// A MockSession whose default ExecuteModel() action answers asynchronously:
// it posts the streaming callback, bound to the FakeExecute() result for the
// request, to the current default single-thread task runner.
class MockClassifierSession : public MockSession {
 public:
  MockClassifierSession() {
    ON_CALL(*this, ExecuteModel(_, _))
        .WillByDefault(
            [](const google::protobuf::MessageLite& request_metadata,
               OptimizationGuideModelExecutionResultStreamingCallback
                   callback) {
              // Compute the canned result now; deliver it from a posted task.
              auto fake_result = FakeExecute(request_metadata);
              base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
                  FROM_HERE, base::BindOnce(std::move(callback),
                                            std::move(fake_result)));
            });
  }
};
// A mock model executor whose default StartSession() action returns a new
// MockSessionWrapper pointing at the single `session_` member, so every
// session started through this executor is backed by the same mock session.
class MockExecutor : public MockOptimizationGuideModelExecutor {
 public:
  MockExecutor() {
    ON_CALL(*this, StartSession(_, _)).WillByDefault([this] {
      auto wrapper =
          std::make_unique<optimization_guide::MockSessionWrapper>(&session_);
      return wrapper;
    });
  }

  // The mock session that each returned wrapper points at.
  NiceMock<MockClassifierSession> session_;
};
// Test fixture for MlIntentClassifier. It owns the task environment the tests
// use to run posted tasks (see the RunUntilIdle() calls in each test).
class HistoryEmbeddingsMlIntentClassifierTest : public testing::Test {
public:
// Intentionally empty: the fixture needs no per-test setup.
void SetUp() override {}
protected:
// Task environment shared by each test body.
base::test::TaskEnvironment task_environment_;
};
// A query ending in "?" is marked answer seeking by the fake model session,
// so the classifier is expected to report SUCCESS with an answerable query.
TEST_F(HistoryEmbeddingsMlIntentClassifierTest, IntentYes) {
  NiceMock<MockExecutor> executor;
  MlIntentClassifier intent_classifier(&executor);
  {
    // Scoped so the future is destroyed before the final RunUntilIdle().
    base::test::TestFuture<ComputeIntentStatus, bool> future;
    intent_classifier.ComputeQueryIntent("query?", future.GetCallback());
    auto [status, is_query_answerable] = future.Take();
    EXPECT_EQ(status, ComputeIntentStatus::SUCCESS);
    EXPECT_TRUE(is_query_answerable);
  }
  task_environment_.RunUntilIdle();  // Trigger DeleteSoon()
}
// A query without a trailing "?" or "!" is marked not answer seeking by the
// fake model session, so the classifier is expected to report SUCCESS with a
// non-answerable query.
TEST_F(HistoryEmbeddingsMlIntentClassifierTest, IntentNo) {
  NiceMock<MockExecutor> executor;
  MlIntentClassifier intent_classifier(&executor);
  {
    // Scoped so the future is destroyed before the final RunUntilIdle().
    base::test::TestFuture<ComputeIntentStatus, bool> future;
    intent_classifier.ComputeQueryIntent("query", future.GetCallback());
    auto [status, is_query_answerable] = future.Take();
    EXPECT_EQ(status, ComputeIntentStatus::SUCCESS);
    EXPECT_FALSE(is_query_answerable);
  }
  task_environment_.RunUntilIdle();  // Trigger DeleteSoon()
}
// A query ending in "!" makes the fake model session return a failure result;
// the classifier is expected to surface that as EXECUTION_FAILURE. The boolean
// half of the result is not checked in this case.
TEST_F(HistoryEmbeddingsMlIntentClassifierTest, ExecutionFails) {
  NiceMock<MockExecutor> model_executor;
  MlIntentClassifier classifier(&model_executor);
  {
    // Scoped so the future is destroyed before the final RunUntilIdle().
    base::test::TestFuture<ComputeIntentStatus, bool> intent_future;
    classifier.ComputeQueryIntent("query!", intent_future.GetCallback());
    auto [intent_status, answerable] = intent_future.Take();
    EXPECT_EQ(intent_status, ComputeIntentStatus::EXECUTION_FAILURE);
  }
  // Flush remaining posted work so the pending DeleteSoon() runs.
  task_environment_.RunUntilIdle();
}
// When the executor cannot start a session (StartSession() yields nullptr),
// the classifier is expected to report MODEL_UNAVAILABLE. The boolean half of
// the result is not checked in this case.
TEST_F(HistoryEmbeddingsMlIntentClassifierTest, FailToCreateSession) {
  MockOptimizationGuideModelExecutor model_executor;
  // Every attempt to start a session fails.
  EXPECT_CALL(model_executor, StartSession(_, _)).WillRepeatedly([]() {
    return nullptr;
  });

  MlIntentClassifier classifier(&model_executor);
  {
    // Scoped so the future is destroyed before the final RunUntilIdle().
    base::test::TestFuture<ComputeIntentStatus, bool> intent_future;
    classifier.ComputeQueryIntent("query?", intent_future.GetCallback());
    auto [intent_status, answerable] = intent_future.Take();
    EXPECT_EQ(intent_status, ComputeIntentStatus::MODEL_UNAVAILABLE);
  }
  // Flush remaining posted work so the pending DeleteSoon() runs.
  task_environment_.RunUntilIdle();
}
} // namespace
} // namespace history_embeddings