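Switches `clu_annotator`, `bert_clu_annotator` and `clu_lib:tflite_modules` build rules to `cc_library_with_tflite`, and updates the CLU TFLite modules to use `core::TfLiteEngine::Interpreter` instead of `tflite::Interpreter`.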
PiperOrigin-RevId: 452161028
diff --git a/tensorflow_lite_support/cc/task/text/BUILD b/tensorflow_lite_support/cc/task/text/BUILD
index 8970273..c74a627 100644
--- a/tensorflow_lite_support/cc/task/text/BUILD
+++ b/tensorflow_lite_support/cc/task/text/BUILD
@@ -138,19 +138,21 @@
],
)
-cc_library(
+cc_library_with_tflite(
name = "clu_annotator",
hdrs = [
"clu_annotator.h",
],
- deps = [
+ tflite_deps = [
"//tensorflow_lite_support/cc/task/core:base_task_api",
"//tensorflow_lite_support/cc/task/core:tflite_engine",
+ ],
+ deps = [
"//tensorflow_lite_support/cc/task/text/proto:clu_proto_inc",
],
)
-cc_library(
+cc_library_with_tflite(
name = "bert_clu_annotator",
srcs = [
"bert_clu_annotator.cc",
@@ -158,12 +160,14 @@
hdrs = [
"bert_clu_annotator.h",
],
- deps = [
+ tflite_deps = [
":clu_annotator",
- "//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/task/core:task_api_factory",
- "//tensorflow_lite_support/cc/task/core:task_utils",
"//tensorflow_lite_support/cc/task/text/clu_lib:tflite_modules",
+ ],
+ deps = [
+ "//tensorflow_lite_support/cc/port:status_macros",
+ "//tensorflow_lite_support/cc/task/core:task_utils",
"//tensorflow_lite_support/cc/task/text/proto:bert_clu_annotator_options_proto_inc",
"//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer",
"//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",
diff --git a/tensorflow_lite_support/cc/task/text/clu_lib/BUILD b/tensorflow_lite_support/cc/task/text/clu_lib/BUILD
index 13e63db..6fb0526 100644
--- a/tensorflow_lite_support/cc/task/text/clu_lib/BUILD
+++ b/tensorflow_lite_support/cc/task/text/clu_lib/BUILD
@@ -1,12 +1,20 @@
+load(
+ "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl",
+ "cc_library_with_tflite",
+)
+
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
-cc_library(
+cc_library_with_tflite(
name = "tflite_modules",
srcs = ["tflite_modules.cc"],
hdrs = ["tflite_modules.h"],
+ tflite_deps = [
+ "//tensorflow_lite_support/cc/task/core:tflite_engine",
+ ],
deps = [
":bert_utils",
":constants",
diff --git a/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc b/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
index c16f5bc..784bdb1 100644
--- a/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
+++ b/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
@@ -42,8 +42,8 @@
const CluRequest& request, int token_id_tensor_idx,
int token_mask_tensor_idx, int token_type_id_tensor_idx,
const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
- size_t max_seq_len, int max_history_turns, tflite::Interpreter* interpreter,
- Artifacts* artifacts) {
+ size_t max_seq_len, int max_history_turns,
+ core::TfLiteEngine::Interpreter* interpreter, Artifacts* artifacts) {
size_t seq_len;
int64_t* tokens_tensor =
interpreter->typed_input_tensor<int64_t>(token_id_tensor_idx);
@@ -116,8 +116,9 @@
return absl::OkStatus();
}
-absl::StatusOr<int> GetInputSeqDimSize(const size_t input_idx,
- const tflite::Interpreter* interpreter) {
+absl::StatusOr<int> GetInputSeqDimSize(
+ const size_t input_idx,
+ const core::TfLiteEngine::Interpreter* interpreter) {
if (input_idx >= interpreter->inputs().size()) {
return absl::InternalError(absl::StrCat(
"input_idx should be less than interpreter input numbers. ", input_idx,
@@ -132,14 +133,15 @@
return tflite::SizeOfDimension(tensor, 1);
}
-absl::Status AbstractModule::Init(tflite::Interpreter* interpreter,
+absl::Status AbstractModule::Init(core::TfLiteEngine::Interpreter* interpreter,
const BertCluAnnotatorOptions* options) {
interpreter_ = interpreter;
return absl::OkStatus();
}
absl::StatusOr<std::unique_ptr<AbstractModule>> UtteranceSeqModule::Create(
- tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options,
const tflite::support::text::tokenizer::BertTokenizer* tokenizer) {
auto out = std::make_unique<UtteranceSeqModule>();
@@ -187,7 +189,8 @@
}
absl::StatusOr<std::unique_ptr<AbstractModule>> DomainModule::Create(
- tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options) {
auto out = std::make_unique<DomainModule>();
out->tensor_index_map_ = tensor_index_map;
@@ -215,7 +218,8 @@
}
absl::StatusOr<std::unique_ptr<AbstractModule>> IntentModule::Create(
- tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options) {
auto out = std::make_unique<IntentModule>();
out->tensor_index_map_ = tensor_index_map;
@@ -261,7 +265,8 @@
}
absl::StatusOr<std::unique_ptr<AbstractModule>> SlotModule::Create(
- tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options) {
auto out = std::make_unique<SlotModule>();
out->tensor_index_map_ = tensor_index_map;
diff --git a/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h b/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
index 5a9f183..d9c74a4 100644
--- a/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
+++ b/tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
@@ -18,7 +18,7 @@
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
-#include "tensorflow/lite/interpreter.h"
+#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/text/proto/bert_clu_annotator_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/text/proto/clu_proto_inc.h"
#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
@@ -76,7 +76,7 @@
protected:
AbstractModule() = default;
- absl::Status Init(Interpreter* interpreter,
+ absl::Status Init(core::TfLiteEngine::Interpreter* interpreter,
const BertCluAnnotatorOptions* options);
using NamesAndConfidences =
@@ -88,7 +88,7 @@
int names_tensor_idx, int scores_tensor_idx) const;
// TFLite interpreter
- Interpreter* interpreter_ = nullptr;
+ core::TfLiteEngine::Interpreter* interpreter_ = nullptr;
const TensorIndexMap* tensor_index_map_ = nullptr;
};
@@ -98,7 +98,8 @@
class UtteranceSeqModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options,
const tflite::support::text::tokenizer::BertTokenizer* tokenizer);
@@ -116,7 +117,8 @@
class DomainModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options);
absl::Status Postprocess(Artifacts* artifacts,
@@ -130,7 +132,8 @@
class IntentModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options);
absl::Status Postprocess(Artifacts* artifacts,
@@ -145,7 +148,8 @@
class SlotModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
- Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
+ core::TfLiteEngine::Interpreter* interpreter,
+ const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options);
absl::Status Postprocess(Artifacts* artifacts,