// Copyright 2020 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <memory>

#include "base/run_loop.h"
#include "base/task/thread_pool/thread_pool_instance.h"
#include "base/test/metrics/histogram_tester.h"
#include "chrome/services/machine_learning/metrics.h"
#include "chrome/services/machine_learning/public/cpp/service_connection.h"
#include "chrome/services/machine_learning/public/cpp/test_support/machine_learning_test_utils.h"
#include "chrome/services/machine_learning/public/mojom/decision_tree.mojom.h"
#include "chrome/services/machine_learning/public/mojom/machine_learning_service.mojom.h"
#include "chrome/test/base/in_process_browser_test.h"
#include "components/metrics/content/subprocess_metrics_provider.h"
#include "content/public/browser/service_process_host.h"
#include "content/public/browser/service_process_info.h"
#include "content/public/test/browser_test.h"
#include "content/public/test/browser_test_utils.h"
#include "mojo/public/cpp/bindings/remote.h"
namespace machine_learning {
namespace {
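
// Observes content::ServiceProcessHost events to detect when the machine
// learning service process is launched and when it terminates normally.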
class ServiceProcessObserver : public content::ServiceProcessHost::Observer {
public:
ServiceProcessObserver() { content::ServiceProcessHost::AddObserver(this); }
~ServiceProcessObserver() override {
content::ServiceProcessHost::RemoveObserver(this);
}
ServiceProcessObserver(const ServiceProcessObserver&) = delete;
ServiceProcessObserver& operator=(const ServiceProcessObserver&) = delete;
// Whether a service launch has been detected.
bool IsLaunched() const { return is_launched_; }
// Whether a normal service termination has been detected.
bool IsTerminated() const { return is_terminated_; }
// Runs |launch_wait_loop_| until a service launch is detected.
void WaitForLaunch() { launch_wait_loop_.Run(); }
// Runs |terminate_wait_loop_| until a normal service termination is
// detected.
void WaitForTerminate() { terminate_wait_loop_.Run(); }
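
// content::ServiceProcessHost::Observer: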
void OnServiceProcessLaunched(
const content::ServiceProcessInfo& info) override {
if (info.IsService<mojom::MachineLearningService>()) {
is_launched_ = true;
if (launch_wait_loop_.running())
launch_wait_loop_.Quit();
}
}
void OnServiceProcessTerminatedNormally(
const content::ServiceProcessInfo& info) override {
if (info.IsService<mojom::MachineLearningService>()) {
is_terminated_ = true;
if (terminate_wait_loop_.running())
terminate_wait_loop_.Quit();
}
}
private:
base::RunLoop launch_wait_loop_;
base::RunLoop terminate_wait_loop_;
bool is_launched_ = false;
bool is_terminated_ = false;
};

// Retries fetching |histogram_name| until it contains at least |count| samples.
void RetryForHistogramUntilCountReached(base::HistogramTester* histogram_tester,
const std::string& histogram_name,
size_t count) {
while (true) {
const std::vector<base::Bucket> buckets =
histogram_tester->GetAllSamples(histogram_name);
size_t total_count = 0;
for (const auto& bucket : buckets)
total_count += bucket.count;
if (total_count >= count)
return;
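// Histograms recorded in the service process only become visible to the
// browser process after being fetched and merged below.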
content::FetchHistogramsFromChildProcesses();
::metrics::SubprocessMetricsProvider::MergeHistogramDeltasForTesting();
base::RunLoop().RunUntilIdle();
}
}

}  // namespace

using MachineLearningServiceBrowserTest = InProcessBrowserTest;
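
// Tests that requesting the service launches a new service process and that
// resetting the connection terminates it, recording the launch, request, and
// termination histograms.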
IN_PROC_BROWSER_TEST_F(MachineLearningServiceBrowserTest, LaunchAndTerminate) {
base::HistogramTester histogram_tester;
ServiceProcessObserver observer;
auto* service_connection = ServiceConnection::GetInstance();
service_connection->GetService();
observer.WaitForLaunch();
RetryForHistogramUntilCountReached(
&histogram_tester, machine_learning::metrics::kServiceLaunch, 1);
EXPECT_TRUE(service_connection);
EXPECT_TRUE(observer.IsLaunched());
histogram_tester.ExpectTotalCount(machine_learning::metrics::kServiceLaunch,
1);
histogram_tester.ExpectUniqueSample(
machine_learning::metrics::kServiceRequested,
machine_learning::metrics::MLServiceRequestStatus::
kRequestedServiceNotLaunched,
1);
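// Dropping the service connection should cause the service process to
// terminate normally.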
service_connection->ResetServiceForTesting();
observer.WaitForTerminate();
RetryForHistogramUntilCountReached(
&histogram_tester, machine_learning::metrics::kServiceNormalTermination,
1);
EXPECT_TRUE(observer.IsTerminated());
histogram_tester.ExpectTotalCount(
machine_learning::metrics::kServiceNormalTermination, 1);
}
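
// Tests that a second GetService() call reuses the already-running service
// process: the same pointer is returned and only a single launch is recorded.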
IN_PROC_BROWSER_TEST_F(MachineLearningServiceBrowserTest,
MultipleLaunchesReusesSharedProcess) {
base::HistogramTester histogram_tester;
ServiceProcessObserver observer;
auto* service_connection = ServiceConnection::GetInstance();
auto* service_ptr1 = service_connection->GetService();
observer.WaitForLaunch();
EXPECT_TRUE(service_connection);
EXPECT_TRUE(observer.IsLaunched());
RetryForHistogramUntilCountReached(
&histogram_tester, machine_learning::metrics::kServiceLaunch, 1);
histogram_tester.ExpectTotalCount(machine_learning::metrics::kServiceLaunch,
1);
histogram_tester.ExpectBucketCount(
machine_learning::metrics::kServiceRequested,
machine_learning::metrics::MLServiceRequestStatus::
kRequestedServiceNotLaunched,
1);
auto* service_ptr2 = service_connection->GetService();
EXPECT_EQ(service_ptr1, service_ptr2);
RetryForHistogramUntilCountReached(
&histogram_tester, machine_learning::metrics::kServiceRequested, 2);
histogram_tester.ExpectTotalCount(machine_learning::metrics::kServiceLaunch,
1);
histogram_tester.ExpectBucketCount(
machine_learning::metrics::kServiceRequested,
machine_learning::metrics::MLServiceRequestStatus::
kRequestedServiceLaunched,
1);
}
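
// Tests that loading a model from an invalid spec reports kModelSpecError and
// leaves |predictor| disconnected.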
IN_PROC_BROWSER_TEST_F(MachineLearningServiceBrowserTest,
LoadInvalidDecisionTreeModel) {
base::HistogramTester histogram_tester;
ServiceProcessObserver observer;
auto run_loop = std::make_unique<base::RunLoop>();
auto* service_connection = ServiceConnection::GetInstance();
mojo::Remote<mojom::DecisionTreePredictor> predictor;
mojom::LoadModelResult result = mojom::LoadModelResult::kLoadModelError;
service_connection->LoadDecisionTreeModel(
mojom::DecisionTreeModelSpec::New("Invalid model spec"),
predictor.BindNewPipeAndPassReceiver(),
base::BindOnce(
[](mojom::LoadModelResult* p_result, base::RunLoop* run_loop,
mojom::LoadModelResult result) {
*p_result = result;
run_loop->Quit();
},
&result, run_loop.get()));
run_loop->Run();
EXPECT_TRUE(observer.IsLaunched());
EXPECT_EQ(mojom::LoadModelResult::kModelSpecError, result);
RetryForHistogramUntilCountReached(
&histogram_tester,
machine_learning::metrics::kDecisionTreeModelLoadResult, 1);
histogram_tester.ExpectTotalCount(machine_learning::metrics::kServiceLaunch,
1);
histogram_tester.ExpectUniqueSample(
machine_learning::metrics::kServiceRequested,
machine_learning::metrics::MLServiceRequestStatus::
kRequestedServiceNotLaunched,
1);
histogram_tester.ExpectUniqueSample(
machine_learning::metrics::kDecisionTreeModelLoadResult,
mojom::LoadModelResult::kModelSpecError, 1);
// Flush so that |predictor| becomes aware of potential disconnection.
predictor.FlushForTesting();
EXPECT_FALSE(predictor.is_connected());
}
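
// Tests that a valid decision tree model loads successfully and that a
// prediction through |predictor| returns the expected result and score,
// recording the model load and prediction histograms.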
IN_PROC_BROWSER_TEST_F(MachineLearningServiceBrowserTest,
LoadValidDecisionTreeModelAndPredict) {
base::HistogramTester histogram_tester;
ServiceProcessObserver observer;
auto* service_connection = ServiceConnection::GetInstance();
const mojom::DecisionTreePredictionResult expected_prediction_result =
mojom::DecisionTreePredictionResult::kTrue;
auto model_proto =
testing::GetModelProtoForPredictionResult(expected_prediction_result);
mojo::Remote<mojom::DecisionTreePredictor> predictor;
mojom::LoadModelResult load_result = mojom::LoadModelResult::kLoadModelError;
auto run_loop = std::make_unique<base::RunLoop>();
service_connection->LoadDecisionTreeModel(
mojom::DecisionTreeModelSpec::New(model_proto->SerializeAsString()),
predictor.BindNewPipeAndPassReceiver(),
base::BindOnce(
[](mojom::LoadModelResult* p_result, base::RunLoop* run_loop,
mojom::LoadModelResult result) {
*p_result = result;
run_loop->Quit();
},
&load_result, run_loop.get()));
run_loop->Run();
RetryForHistogramUntilCountReached(
&histogram_tester,
machine_learning::metrics::kDecisionTreeModelLoadResult, 1);
EXPECT_TRUE(observer.IsLaunched());
EXPECT_EQ(mojom::LoadModelResult::kOk, load_result);
histogram_tester.ExpectTotalCount(machine_learning::metrics::kServiceLaunch,
1);
histogram_tester.ExpectTotalCount(
machine_learning::metrics::kDecisionTreeModelValidationLatency, 1);
histogram_tester.ExpectUniqueSample(
machine_learning::metrics::kServiceRequested,
machine_learning::metrics::MLServiceRequestStatus::
kRequestedServiceNotLaunched,
1);
histogram_tester.ExpectUniqueSample(
machine_learning::metrics::kDecisionTreeModelLoadResult,
mojom::LoadModelResult::kOk, 1);
// Flush so that |predictor| becomes aware of potential disconnection.
predictor.FlushForTesting();
EXPECT_TRUE(predictor.is_connected());
auto prediction_result = mojom::DecisionTreePredictionResult::kUnknown;
double prediction_score = 0.0;
// Reset the RunLoop.
run_loop = std::make_unique<base::RunLoop>();
predictor->Predict(
{}, base::BindOnce(
[](mojom::DecisionTreePredictionResult* p_result, double* p_score,
base::RunLoop* runloop,
mojom::DecisionTreePredictionResult result, double score) {
*p_result = result;
*p_score = score;
runloop->Quit();
},
&prediction_result, &prediction_score, run_loop.get()));
run_loop->Run();
predictor.FlushForTesting();
RetryForHistogramUntilCountReached(
&histogram_tester,
machine_learning::metrics::kDecisionTreeModelPredictionResult, 1);
EXPECT_EQ(expected_prediction_result, prediction_result);
EXPECT_GT(prediction_score, 0.0);
histogram_tester.ExpectTotalCount(
machine_learning::metrics::kDecisionTreeModelEvaluationLatency, 1);
histogram_tester.ExpectUniqueSample(
machine_learning::metrics::kDecisionTreeModelPredictionResult,
prediction_result, 1);
}

}  // namespace machine_learning