blob: a3d2bbd87203544a49dd67e961af1359f56b2905 [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "content/browser/ml/ml_service_impl_cros.h"

#include <stdint.h>

#include <string>
#include <utility>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "chromeos/services/machine_learning/public/cpp/fake_service_connection.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "components/ml/mojom/ml_service.mojom.h"
#include "content/public/test/test_renderer_host.h"
#include "mojo/public/cpp/base/big_buffer.h"
#include "mojo/public/cpp/bindings/pending_receiver.h"
#include "mojo/public/cpp/bindings/pending_remote.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace content {
class MLServiceImplCrOSTest : public RenderViewHostTestHarness {
public:
void SetUp() override {
RenderViewHostTestHarness::SetUp();
chromeos::machine_learning::ServiceConnection::
UseFakeServiceConnectionForTesting(&fake_ml_service_connection_);
chromeos::machine_learning::ServiceConnection::GetInstance()->Initialize();
}
chromeos::machine_learning::FakeServiceConnectionImpl&
GetMlServiceConnection() {
return fake_ml_service_connection_;
}
private:
chromeos::machine_learning::FakeServiceConnectionImpl
fake_ml_service_connection_;
};
// Tests the successful case of model loader creation: the reply carries kOk
// together with a valid ModelLoader pending remote.
TEST_F(MLServiceImplCrOSTest, CreateModelLoaderOK) {
  mojo::Remote<ml::model_loader::mojom::MLService> service_remote;
  CrOSMLServiceImpl::Create(service_remote.BindNewPipeAndPassReceiver());

  auto options = ml::model_loader::mojom::CreateModelLoaderOptions::New();
  options->num_threads = 2;
  options->model_format = ml::model_loader::mojom::ModelFormat::kTfLite;

  GetMlServiceConnection().SetCreateWebPlatformModelLoaderResult(
      ml::model_loader::mojom::CreateModelLoaderResult::kOk);

  bool is_callback_called = false;
  base::RunLoop run_loop;
  service_remote->CreateModelLoader(
      std::move(options),
      base::BindLambdaForTesting(
          [&](ml::model_loader::mojom::CreateModelLoaderResult result,
              mojo::PendingRemote<ml::model_loader::mojom::ModelLoader>
                  remote) {
            EXPECT_EQ(result,
                      ml::model_loader::mojom::CreateModelLoaderResult::kOk);
            // A successful creation is only useful if it hands back a loader
            // the caller can bind and talk to.
            EXPECT_TRUE(remote.is_valid());
            is_callback_called = true;
            run_loop.Quit();
          }));
  run_loop.Run();
  EXPECT_TRUE(is_callback_called);
}
// Tests the failure case of model loader creation: the error code programmed
// into the fake connection is the result carried by the CreateModelLoader()
// reply.
TEST_F(MLServiceImplCrOSTest, CreateModelLoaderFailed) {
  constexpr auto kExpectedResult =
      ml::model_loader::mojom::CreateModelLoaderResult::kUnknownError;

  mojo::Remote<ml::model_loader::mojom::MLService> ml_service;
  CrOSMLServiceImpl::Create(ml_service.BindNewPipeAndPassReceiver());

  auto loader_options =
      ml::model_loader::mojom::CreateModelLoaderOptions::New();
  loader_options->num_threads = 2;
  loader_options->model_format =
      ml::model_loader::mojom::ModelFormat::kTfLite;

  GetMlServiceConnection().SetCreateWebPlatformModelLoaderResult(
      kExpectedResult);

  // The result is checked inside the callback; |reply_received| confirms that
  // the callback actually ran before the run loop exited.
  bool reply_received = false;
  base::RunLoop wait_for_reply;
  auto on_created = base::BindLambdaForTesting(
      [&](ml::model_loader::mojom::CreateModelLoaderResult result,
          mojo::PendingRemote<ml::model_loader::mojom::ModelLoader>) {
        EXPECT_EQ(result, kExpectedResult);
        reply_received = true;
        wait_for_reply.Quit();
      });
  ml_service->CreateModelLoader(std::move(loader_options),
                                std::move(on_created));
  wait_for_reply.Run();
  EXPECT_TRUE(reply_received);
}
// Tests the failure case of model loading: loader creation succeeds, but the
// fake connection is configured so that Load() replies with kUnknownError and
// null model info.
TEST_F(MLServiceImplCrOSTest, LoadModelFailed) {
  mojo::Remote<ml::model_loader::mojom::MLService> service_remote;
  CrOSMLServiceImpl::Create(service_remote.BindNewPipeAndPassReceiver());
  auto options = ml::model_loader::mojom::CreateModelLoaderOptions::New();
  options->num_threads = 2;
  options->model_format = ml::model_loader::mojom::ModelFormat::kTfLite;
  GetMlServiceConnection().SetCreateWebPlatformModelLoaderResult(
      ml::model_loader::mojom::CreateModelLoaderResult::kOk);

  // Creates a model loader.
  mojo::Remote<ml::model_loader::mojom::ModelLoader> loader_remote;
  bool is_callback_called = false;
  base::RunLoop run_loop_create_loader;
  service_remote->CreateModelLoader(
      std::move(options),
      base::BindLambdaForTesting(
          [&](ml::model_loader::mojom::CreateModelLoaderResult result,
              mojo::PendingRemote<ml::model_loader::mojom::ModelLoader>
                  remote) {
            EXPECT_EQ(result,
                      ml::model_loader::mojom::CreateModelLoaderResult::kOk);
            loader_remote.Bind(std::move(remote));
            is_callback_called = true;
            run_loop_create_loader.Quit();
          }));
  run_loop_create_loader.Run();
  // Everything below calls through |loader_remote|. If creation never replied
  // or did not yield a bound loader, stop now instead of calling through an
  // unbound mojo::Remote.
  ASSERT_TRUE(is_callback_called);
  ASSERT_TRUE(loader_remote.is_bound());

  // Loads a model.
  GetMlServiceConnection().SetLoadWebPlatformModelResult(
      ml::model_loader::mojom::LoadModelResult::kUnknownError);
  mojo_base::BigBuffer model_buffer;
  base::RunLoop run_loop_load_model;
  is_callback_called = false;
  loader_remote->Load(
      std::move(model_buffer),
      base::BindLambdaForTesting(
          [&](ml::model_loader::mojom::LoadModelResult result,
              mojo::PendingRemote<ml::model_loader::mojom::Model> remote,
              ml::model_loader::mojom::ModelInfoPtr info) {
            EXPECT_EQ(result,
                      ml::model_loader::mojom::LoadModelResult::kUnknownError);
            EXPECT_TRUE(info.is_null());
            is_callback_called = true;
            run_loop_load_model.Quit();
          }));
  run_loop_load_model.Run();
  EXPECT_TRUE(is_callback_called);
}
// Tests the case that successfully loads a model and does a computation.
//
// Callbacks only record what they receive and make non-fatal EXPECT_* checks.
// All fatal ASSERT_* checks run in the test body after each run loop exits.
// An ASSERT_* inside a callback would return only from the lambda, skipping
// Quit() and leaving Run() blocked until the test times out.
TEST_F(MLServiceImplCrOSTest, LoadModelAndCompute) {
  mojo::Remote<ml::model_loader::mojom::MLService> service_remote;
  CrOSMLServiceImpl::Create(service_remote.BindNewPipeAndPassReceiver());
  auto options = ml::model_loader::mojom::CreateModelLoaderOptions::New();
  options->num_threads = 2;
  options->model_format = ml::model_loader::mojom::ModelFormat::kTfLite;
  GetMlServiceConnection().SetCreateWebPlatformModelLoaderResult(
      ml::model_loader::mojom::CreateModelLoaderResult::kOk);

  // Creates a model loader.
  mojo::Remote<ml::model_loader::mojom::ModelLoader> loader_remote;
  bool is_callback_called = false;
  base::RunLoop run_loop_create_loader;
  service_remote->CreateModelLoader(
      std::move(options),
      base::BindLambdaForTesting(
          [&](ml::model_loader::mojom::CreateModelLoaderResult result,
              mojo::PendingRemote<ml::model_loader::mojom::ModelLoader>
                  remote) {
            EXPECT_EQ(result,
                      ml::model_loader::mojom::CreateModelLoaderResult::kOk);
            loader_remote.Bind(std::move(remote));
            is_callback_called = true;
            run_loop_create_loader.Quit();
          }));
  run_loop_create_loader.Run();
  // |loader_remote| is called below; bail out if it was never bound.
  ASSERT_TRUE(is_callback_called);
  ASSERT_TRUE(loader_remote.is_bound());

  // Loads a model. The fake is primed with the model info it should report.
  GetMlServiceConnection().SetLoadWebPlatformModelResult(
      ml::model_loader::mojom::LoadModelResult::kOk);
  auto model_info = ml::model_loader::mojom::ModelInfo::New();
  model_info->input_tensor_info["input_test"] =
      ml::model_loader::mojom::TensorInfo::New(
          123, ml::model_loader::mojom::DataType::kBool,
          std::vector<uint32_t>{1, 2, 3});
  model_info->output_tensor_info["output_test1"] =
      ml::model_loader::mojom::TensorInfo::New(
          321, ml::model_loader::mojom::DataType::kInt32,
          std::vector<uint32_t>{10, 20, 30});
  model_info->output_tensor_info["output_test2"] =
      ml::model_loader::mojom::TensorInfo::New(
          567, ml::model_loader::mojom::DataType::kFloat32,
          std::vector<uint32_t>{8, 10, 12});
  GetMlServiceConnection().SetWebPlatformModelInfo(std::move(model_info));

  mojo::Remote<ml::model_loader::mojom::Model> model_remote;
  ml::model_loader::mojom::ModelInfoPtr loaded_info;
  mojo_base::BigBuffer model_buffer;
  base::RunLoop run_loop_load_model;
  is_callback_called = false;
  loader_remote->Load(
      std::move(model_buffer),
      base::BindLambdaForTesting(
          [&](ml::model_loader::mojom::LoadModelResult result,
              mojo::PendingRemote<ml::model_loader::mojom::Model> remote,
              ml::model_loader::mojom::ModelInfoPtr info) {
            EXPECT_EQ(result, ml::model_loader::mojom::LoadModelResult::kOk);
            loaded_info = std::move(info);
            model_remote.Bind(std::move(remote));
            is_callback_called = true;
            run_loop_load_model.Quit();
          }));
  run_loop_load_model.Run();
  ASSERT_TRUE(is_callback_called);
  // |loaded_info| is dereferenced below, so a null value must end the test
  // rather than crash it.
  ASSERT_FALSE(loaded_info.is_null());

  ASSERT_EQ(loaded_info->input_tensor_info.size(), 1u);
  ASSERT_TRUE(loaded_info->input_tensor_info.contains("input_test"));
  const auto& input_test =
      loaded_info->input_tensor_info.find("input_test")->second;
  EXPECT_EQ(input_test->byte_size, 123u);
  EXPECT_EQ(input_test->data_type, ml::model_loader::mojom::DataType::kBool);
  EXPECT_EQ(input_test->dimensions, std::vector<uint32_t>({1, 2, 3}));

  ASSERT_EQ(loaded_info->output_tensor_info.size(), 2u);
  ASSERT_TRUE(loaded_info->output_tensor_info.contains("output_test1"));
  const auto& output_test1 =
      loaded_info->output_tensor_info.find("output_test1")->second;
  EXPECT_EQ(output_test1->byte_size, 321u);
  EXPECT_EQ(output_test1->data_type,
            ml::model_loader::mojom::DataType::kInt32);
  EXPECT_EQ(output_test1->dimensions, std::vector<uint32_t>({10, 20, 30}));

  ASSERT_TRUE(loaded_info->output_tensor_info.contains("output_test2"));
  const auto& output_test2 =
      loaded_info->output_tensor_info.find("output_test2")->second;
  EXPECT_EQ(output_test2->byte_size, 567u);
  EXPECT_EQ(output_test2->data_type,
            ml::model_loader::mojom::DataType::kFloat32);
  EXPECT_EQ(output_test2->dimensions, std::vector<uint32_t>({8, 10, 12}));

  // |model_remote| is called below; bail out if it was never bound.
  ASSERT_TRUE(model_remote.is_bound());

  // Does a computation.
  GetMlServiceConnection().SetWebPlatformModelComputeResult(
      ml::model_loader::mojom::ComputeResult::kOk);
  base::flat_map<std::string, std::vector<uint8_t>> compute_output;
  compute_output["some output1"] = std::vector<uint8_t>({123, 23, 21});
  compute_output["some output2"] = std::vector<uint8_t>({8, 5, 2});
  GetMlServiceConnection().SetOutputWebPlatformModelCompute(
      std::move(compute_output));

  base::flat_map<std::string, std::vector<uint8_t>> compute_input;
  compute_input["random input"] = std::vector<uint8_t>({0, 3, 5});
  absl::optional<base::flat_map<std::string, std::vector<uint8_t>>>
      received_outputs;
  base::RunLoop run_loop_compute;
  is_callback_called = false;
  model_remote->Compute(
      std::move(compute_input),
      base::BindLambdaForTesting(
          [&](ml::model_loader::mojom::ComputeResult result,
              const absl::optional<base::flat_map<
                  std::string, std::vector<uint8_t>>>& output_tensors) {
            // The fake was told to report kOk; make sure that is forwarded.
            EXPECT_EQ(result, ml::model_loader::mojom::ComputeResult::kOk);
            received_outputs = output_tensors;
            is_callback_called = true;
            run_loop_compute.Quit();
          }));
  run_loop_compute.Run();
  ASSERT_TRUE(is_callback_called);
  ASSERT_TRUE(received_outputs.has_value());
  ASSERT_EQ(received_outputs->size(), 2u);
  ASSERT_TRUE(received_outputs->contains("some output1"));
  EXPECT_EQ(received_outputs->find("some output1")->second,
            std::vector<uint8_t>({123, 23, 21}));
  ASSERT_TRUE(received_outputs->contains("some output2"));
  EXPECT_EQ(received_outputs->find("some output2")->second,
            std::vector<uint8_t>({8, 5, 2}));
}
} // namespace content