blob: cf1fd38e4c853b8830fc0c5ea380459c29869c70 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef SERVICES_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDER_EXECUTION_TASK_H_
#define SERVICES_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDER_EXECUTION_TASK_H_
#include <memory>
#include <optional>
#include <vector>

#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"
namespace passage_embeddings {

// Value produced by `PassageEmbedderExecutionTask`: a flat vector of floats.
// NOTE(review): presumably the embedding of one passage, with a length fixed by
// the model's output tensor — confirm against the implementation.
using OutputType = std::vector<float>;

// Argument accepted by `Execute()`, `Preprocess()` and `Postprocess()`. It is a
// const reference so the input vector is never copied when passed along.
// NOTE(review): the ints are presumably tokenized passage ids — confirm.
using InputType = const std::vector<int>&;

// Passage-embedder task built on the tflite_support task API. It implements
// the `BaseTaskApi` hooks that fill the model's input tensors (`Preprocess`)
// and read the model's output tensors (`Postprocess`).
class PassageEmbedderExecutionTask
    : public tflite::task::core::BaseTaskApi<OutputType, InputType> {
 public:
  // Takes ownership of `tflite_engine`.
  explicit PassageEmbedderExecutionTask(
      std::unique_ptr<tflite::task::core::TfLiteEngine> tflite_engine);

  // Neither copyable nor movable. `BaseTaskApi` already deletes its copy
  // operations; they are restated here so this class's contract is explicit.
  PassageEmbedderExecutionTask(const PassageEmbedderExecutionTask&) = delete;
  PassageEmbedderExecutionTask& operator=(const PassageEmbedderExecutionTask&) =
      delete;

  ~PassageEmbedderExecutionTask() override;

  // Runs the model on `input` and returns the resulting `OutputType`.
  // NOTE(review): presumably returns std::nullopt when inference fails —
  // confirm against the implementation.
  std::optional<OutputType> Execute(InputType input);

 protected:
  // BaseTaskApi:
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          InputType input) override;
  tflite::support::StatusOr<OutputType> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      InputType input) override;
};

}  // namespace passage_embeddings
#endif // SERVICES_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDER_EXECUTION_TASK_H_