blob: cb78924a493c6f4abba89e42fb3090fc66be858e [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ash/input_method/suggestions_service_client.h"

#include <string>
#include <utility>
#include <vector>

#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/text_suggester.mojom.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace chromeos {
namespace {
// Short aliases for the chromeos::ime suggestion types that the tests below
// pass to and compare against SuggestionsServiceClient::RequestSuggestions.
using ::chromeos::ime::TextCompletionCandidate;
using ::chromeos::ime::TextSuggestion;
using ::chromeos::ime::TextSuggestionMode;
using ::chromeos::ime::TextSuggestionType;
class SuggestionsServiceClientTest : public testing::Test {
public:
SuggestionsServiceClientTest() = default;
private:
content::BrowserTaskEnvironment task_environment_;
};
// Builds a TextSuggesterResult with status OK that holds exactly one
// multi-word candidate carrying |text| and |score| (as its normalized score).
// The tests below hand this to the fake service connection as its canned
// text-suggester output.
//
// |text| is a sink parameter: it is taken by value and moved into the
// candidate, so passing a temporary (e.g. a string literal) costs no extra
// copy.
machine_learning::mojom::TextSuggesterResultPtr GenerateMultiWordResult(
    std::string text,
    float score) {
  auto result = machine_learning::mojom::TextSuggesterResult::New();
  result->status = machine_learning::mojom::TextSuggesterResult::Status::OK;

  auto multi_word = machine_learning::mojom::MultiWordSuggestionCandidate::New(
      /*text=*/std::move(text), /*normalized_score=*/score);
  auto candidate = machine_learning::mojom::TextSuggestionCandidate::New();
  candidate->set_multi_word(std::move(multi_word));

  result->candidates.emplace_back(std::move(candidate));
  return result;
}
// Checks that a kCompletion request is answered with the multi-word candidate
// supplied by the fake ML service (tagged with kCompletion mode), and that
// exactly one MultiWord generation-time sample is recorded for the request.
TEST_F(SuggestionsServiceClientTest, ReturnsCompletionResultsFromMojoService) {
  constexpr char kGenerationTimeHistogram[] =
      "InputMethod.Assistive.CandidateGenerationTime.MultiWord";

  machine_learning::FakeServiceConnectionImpl fake_service_connection;
  machine_learning::ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);
  machine_learning::ServiceConnection::GetInstance()->Initialize();

  // Construct fake output.
  machine_learning::mojom::TextSuggesterResultPtr result =
      GenerateMultiWordResult("hi there completion", 0.5f);
  fake_service_connection.SetOutputTextSuggesterResult(std::move(result));

  // HistogramTester only counts samples recorded after it is constructed.
  // Creating it before the client makes the zero-count check below meaningful:
  // it verifies that constructing the client and draining its pending tasks
  // records no generation-time sample.
  base::HistogramTester histogram_tester;
  SuggestionsServiceClient client;
  base::RunLoop().RunUntilIdle();
  histogram_tester.ExpectTotalCount(kGenerationTimeHistogram, 0);

  std::vector<TextSuggestion> returned_results;
  client.RequestSuggestions(
      /*preceding_text=*/"this is some text",
      /*suggestion_mode=*/TextSuggestionMode::kCompletion,
      /*completion_candidates=*/std::vector<TextCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [&](const std::vector<TextSuggestion>& results) {
            returned_results = results;
          }));
  base::RunLoop().RunUntilIdle();

  std::vector<TextSuggestion> expected_results = {
      TextSuggestion{.mode = TextSuggestionMode::kCompletion,
                     .type = TextSuggestionType::kMultiWord,
                     .text = "hi there completion"},
  };
  EXPECT_EQ(returned_results, expected_results);
  histogram_tester.ExpectTotalCount(kGenerationTimeHistogram, 1);
}
// Checks that a kPrediction request is answered with the multi-word candidate
// supplied by the fake ML service (tagged with kPrediction mode), and that
// exactly one MultiWord generation-time sample is recorded for the request.
TEST_F(SuggestionsServiceClientTest, ReturnsPredictionResultsFromMojoService) {
  constexpr char kGenerationTimeHistogram[] =
      "InputMethod.Assistive.CandidateGenerationTime.MultiWord";

  machine_learning::FakeServiceConnectionImpl fake_service_connection;
  machine_learning::ServiceConnection::UseFakeServiceConnectionForTesting(
      &fake_service_connection);
  machine_learning::ServiceConnection::GetInstance()->Initialize();

  // Construct fake output.
  machine_learning::mojom::TextSuggesterResultPtr result =
      GenerateMultiWordResult("hi there prediction", 0.5f);
  fake_service_connection.SetOutputTextSuggesterResult(std::move(result));

  // HistogramTester only counts samples recorded after it is constructed.
  // Creating it before the client makes the zero-count check below meaningful:
  // it verifies that constructing the client and draining its pending tasks
  // records no generation-time sample.
  base::HistogramTester histogram_tester;
  SuggestionsServiceClient client;
  base::RunLoop().RunUntilIdle();
  histogram_tester.ExpectTotalCount(kGenerationTimeHistogram, 0);

  std::vector<TextSuggestion> returned_results;
  client.RequestSuggestions(
      /*preceding_text=*/"this is some text",
      /*suggestion_mode=*/TextSuggestionMode::kPrediction,
      /*completion_candidates=*/std::vector<TextCompletionCandidate>{},
      /*callback=*/
      base::BindLambdaForTesting(
          [&](const std::vector<TextSuggestion>& results) {
            returned_results = results;
          }));
  base::RunLoop().RunUntilIdle();

  std::vector<TextSuggestion> expected_results = {
      TextSuggestion{.mode = TextSuggestionMode::kPrediction,
                     .type = TextSuggestionType::kMultiWord,
                     .text = "hi there prediction"},
  };
  EXPECT_EQ(returned_results, expected_results);
  histogram_tester.ExpectTotalCount(kGenerationTimeHistogram, 1);
}
} // namespace
} // namespace chromeos