/*
* Copyright 2024 The ChromiumOS Authors
* Use of this source code is governed by a BSD-style license that can be
* found in the LICENSE file.
*/

#ifndef COMMON_SIMPLE_MODEL_BUILDER_H_
#define COMMON_SIMPLE_MODEL_BUILDER_H_

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/core/c/c_api_types.h"
#include "tensorflow/lite/core/model_builder.h"

namespace tflite::cros {

// A helper class to simplify the model building process so we don't need to
// manipulate raw FlatBuffer types. This class is thread-compatible.
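//
// Example usage (a sketch; the tensor names and shapes are illustrative, and
// TfLiteAddParams / kTfLiteActNone come from
// tensorflow/lite/core/c/builtin_op_data.h):
//
//   SimpleModelBuilder builder;
//   int lhs = builder.AddInput(
//       {.name = "lhs", .type = kTfLiteFloat32, .shape = {1, 4}});
//   int rhs = builder.AddInput(
//       {.name = "rhs", .type = kTfLiteFloat32, .shape = {1, 4}});
//   int sum = builder.AddOutput(
//       {.name = "sum", .type = kTfLiteFloat32, .shape = {1, 4}});
//   builder.AddOperator<TfLiteAddParams>({
//       .op = kTfLiteBuiltinAdd,
//       .inputs = {lhs, rhs},
//       .outputs = {sum},
//       .params = {.activation = kTfLiteActNone},
//   });
//   std::unique_ptr<FlatBufferModel> model = builder.Build();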
class SimpleModelBuilder {
public:
static constexpr char kSignatureKey[] = "serving_default";
struct TensorArgs {
// TODO(shik): Generate SignatureDef using the tensor names. This is not
// supported by the ModelWriter we are using now, so we have to do it
// ourselves.
std::string name;
TfLiteType type = kTfLiteNoType;
std::vector<int> shape;
// Index of the buffer returned by AddBuffer() that holds this tensor's
// constant data. Per the TFLite schema, the 0-th buffer is a sentinel empty
// buffer, so the default of 0 means no constant data.
int buffer = 0;
// TODO(shik): Support quantizations.
};
template <typename T>
struct OperatorArgs {
TfLiteBuiltinOperator op;
// Input/Output tensor indices.
// TODO(shik): Support intermediate tensors. They are not exposed by the
// Interpreter API, so we would have to use the Subgraph API directly in the
// implementation.
std::vector<int> inputs;
std::vector<int> outputs;
// The corresponding params for the given operator. For example, it should
// be TfLiteAddParams for kTfLiteBuiltinAdd. See
// tensorflow/lite/core/c/builtin_op_data.h for their names.
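//
// For instance, a sketch (assuming the analogous TfLiteMulParams struct for
// kTfLiteBuiltinMul; the tensor indices 0, 1, 2 are illustrative):
//
//   OperatorArgs<TfLiteMulParams> mul_op = {
//       .op = kTfLiteBuiltinMul,
//       .inputs = {0, 1},
//       .outputs = {2},
//       .params = {.activation = kTfLiteActNone},
//   };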
T params;
};
SimpleModelBuilder();
// Move-only.
SimpleModelBuilder(SimpleModelBuilder&& other) = default;
SimpleModelBuilder& operator=(SimpleModelBuilder&& other) = default;
SimpleModelBuilder(const SimpleModelBuilder&) = delete;
SimpleModelBuilder& operator=(const SimpleModelBuilder&) = delete;
// Adds an input tensor. Returns the tensor index.
int AddInput(const TensorArgs& args);
// Adds an output tensor. Returns the tensor index.
int AddOutput(const TensorArgs& args);
// Adds an internal tensor. Returns the tensor index.
int AddInternalTensor(const TensorArgs& args);
// Adds a buffer with data. Returns the buffer index.
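// For example, a constant tensor backed by a buffer (a sketch; the buffer
// size and tensor shape are illustrative only):
//
//   std::vector<uint8_t> raw(4 * sizeof(float));  // Serialized float32 data.
//   int buffer = builder.AddBuffer(std::move(raw));
//   int weights = builder.AddInternalTensor({.name = "weights",
//                                            .type = kTfLiteFloat32,
//                                            .shape = {4},
//                                            .buffer = buffer});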
int AddBuffer(std::vector<uint8_t> data);
// Adds an operator node in the graph.
template <typename T>
void AddOperator(const OperatorArgs<T>& args) {
// We have to use malloc() here since the Interpreter will free this with
// free().
void* builtin_data = malloc(sizeof(T));
*static_cast<T*>(builtin_data) = args.params;
AddOperatorImpl({
.op = args.op,
.inputs = args.inputs,
.outputs = args.outputs,
.params = ScopedBuiltinData(builtin_data),
});
}
// Builds the model. This can be called multiple times.
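// A sketch of consuming the result via the standard TFLite flow (assuming
// tflite::InterpreterBuilder and the builtin op resolver are available to
// the caller):
//
//   std::unique_ptr<FlatBufferModel> model = builder.Build();
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<tflite::Interpreter> interpreter;
//   tflite::InterpreterBuilder(*model, resolver)(&interpreter);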
// TODO(shik): Support saving to a file directly.
std::unique_ptr<FlatBufferModel> Build();
private:
struct FreeDeleter {
void operator()(void* ptr) { free(ptr); }
};
using ScopedBuiltinData = std::unique_ptr<void, FreeDeleter>;
void AddOperatorImpl(OperatorArgs<ScopedBuiltinData> args);
int next_tensor_index_ = 0;
// The added tensors, stored as vectors of (tensor_index, tensor_args) pairs.
std::vector<std::pair<int, TensorArgs>> inputs_;
std::vector<std::pair<int, TensorArgs>> outputs_;
std::vector<std::pair<int, TensorArgs>> internal_tensors_;
// The added buffers. The first one is always an empty vector to match the
// TFLite schema.
std::vector<std::vector<uint8_t>> buffers_;
// The added operators, with params converted to owned, type-erased data
// pointers.
std::vector<OperatorArgs<ScopedBuiltinData>> operators_;
};
} // namespace tflite::cros
#endif // COMMON_SIMPLE_MODEL_BUILDER_H_