blob: df5b778d4d6635d685a0794192e9868c554d3a69 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_
#define CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_
#include <memory>
#include <string>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/macros.h"
#include "chromeos/services/machine_learning/public/cpp/service_connection.h"
#include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/model.mojom.h"
#include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h"
#include "mojo/public/cpp/bindings/binding_set.h"
namespace chromeos {
namespace machine_learning {
// Fake implementation of chromeos::machine_learning::ServiceConnection.
// Handles LoadModel (and Model::CreateGraphExecutor) by binding the incoming
// requests to this object itself.
// Handles GraphExecutor::Execute by always returning the value specified by
// a previous call to SetOutputValue.
// For use with ServiceConnection::UseFakeServiceConnectionForTesting().
class FakeServiceConnectionImpl : public ServiceConnection,
                                  public mojom::Model,
                                  public mojom::GraphExecutor {
 public:
  FakeServiceConnectionImpl();
  // Not copyable: the object holds the mojo binding sets for the requests it
  // serves.
  FakeServiceConnectionImpl(const FakeServiceConnectionImpl&) = delete;
  FakeServiceConnectionImpl& operator=(const FakeServiceConnectionImpl&) =
      delete;
  ~FakeServiceConnectionImpl() override;

  // ServiceConnection:
  // LoadModel may safely be called any number of times; every Model request
  // is bound to this same instance.
  void LoadModel(
      mojom::ModelSpecPtr spec,
      mojom::ModelRequest request,
      mojom::MachineLearningService::LoadModelCallback callback) override;

  // mojom::Model:
  // Binds the GraphExecutor |request| to this same instance.
  void CreateGraphExecutor(
      mojom::GraphExecutorRequest request,
      mojom::Model::CreateGraphExecutorCallback callback) override;

  // mojom::GraphExecutor:
  // Execute() returns the tensor set by SetOutputValue() as the output.
  void Execute(base::flat_map<std::string, mojom::TensorPtr> inputs,
               const std::vector<std::string>& output_names,
               mojom::GraphExecutor::ExecuteCallback callback) override;

  // Sets the output tensor (with the given |shape| and |value|) that Execute()
  // returns. Call this before Execute().
  void SetOutputValue(const std::vector<int64_t>& shape,
                      const std::vector<double>& value);

 private:
  // Model bindings (see LoadModel()).
  mojo::BindingSet<mojom::Model> model_bindings_;
  // GraphExecutor bindings (see CreateGraphExecutor()).
  mojo::BindingSet<mojom::GraphExecutor> graph_bindings_;
  // Tensor handed back by Execute(); populated by SetOutputValue().
  mojom::TensorPtr execute_result_;
};
} // namespace machine_learning
} // namespace chromeos
#endif // CHROMEOS_SERVICES_MACHINE_LEARNING_PUBLIC_CPP_FAKE_SERVICE_CONNECTION_H_