blob: 2937a175c5e3c2a0939080b2d664ef71a210a046 [file] [log] [blame]
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h"
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h"
// Forward declarations of the custom-op registration functions that
// CreateQACustomOpResolver() (later in this file) adds to its resolver. Only
// declarations appear here; the definitions must be supplied elsewhere at link
// time.
namespace tflite {
namespace ops {
namespace custom {
// Registration added to the resolver under the name
// "TFSentencepieceTokenizeOp".
TfLiteRegistration* Register_SENTENCEPIECE_TOKENIZER();
// Registration added to the resolver under the name "RaggedTensorToTensor".
TfLiteRegistration* Register_RAGGED_TENSOR_TO_TENSOR();
} // namespace custom
} // namespace ops
} // namespace tflite
namespace tflite {
namespace task {
namespace text {
namespace retrieval {
using ::absl::Status;
using ::absl::StatusCode;
using internal::QAInput;
using internal::QAOutput;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::task::core::FindTensorByName;
using ::tflite::task::core::PopulateTensor;
using ::tflite::task::core::PopulateVectorToRepeated;
using ::tflite::task::core::TaskAPIFactory;
using FeatureVector = UniversalSentenceEncoderQA::FeatureVector;
namespace {
// Metadata names used by Preprocess() to look up the model's input tensors.
constexpr char kQueryTextTensorName[] = "inp_text";
constexpr char kResponseTextTensorName[] = "res_text";
constexpr char kResponseContextTensorName[] = "res_context";
// Metadata names used by Postprocess() to look up the model's output tensors.
constexpr char kQueryEncodingTensorName[] = "query_encoding";
constexpr char kResponseEncodingTensorName[] = "response_encoding";
// Validates that `options` carries the fields required to build the task.
// Returns OK when `base_options` is present; otherwise returns an
// InvalidArgument status carrying a kInvalidArgumentError payload.
absl::Status SanityCheckOptions(const RetrievalOptions& options) {
  const bool base_options_present = options.has_base_options();
  if (base_options_present) {
    return absl::OkStatus();
  }
  return CreateStatusWithPayload(StatusCode::kInvalidArgument,
                                 "Missing mandatory `base_options` field",
                                 TfLiteSupportStatus::kInvalidArgumentError);
}
// Copies the contents of the model output tensor `src` into the float values
// of `target`. Returns whatever status PopulateVectorToRepeated() reports.
inline absl::Status CopyVector(const TfLiteTensor* src, FeatureVector* target) {
  auto* destination = target->mutable_value_float();
  return PopulateVectorToRepeated(src, destination);
}
// Computes the dot product of `a` and `b`, accumulating in type `T`.
// Returns an InvalidArgument status when the two collections differ in size.
template <class TCollection, class T = float>
tflite::support::StatusOr<T> Dot(const TCollection& a, const TCollection& b) {
  const bool same_size = (a.size() == b.size());
  if (!same_size) {
    return Status(
        StatusCode::kInvalidArgument,
        absl::StrFormat("mismatched vector size %d != %d", a.size(), b.size()));
  }
  T sum = T();
  size_t idx = 0;
  while (idx < a.size()) {
    sum += T(a[idx]) * T(b[idx]);
    ++idx;
  }
  return sum;
}
} // namespace
namespace internal {
// Raw text for one inference run. Preprocess() writes each field into the
// corresponding model input tensor.
struct QAInput {
  std::string query_text;
  std::string response_text;
  std::string response_context;
};
// Result of one inference run, as filled in by Postprocess().
struct QAOutput {
  // Raw tensor pointers are kept (instead of copied data) to avoid an extra
  // copy. They default to nullptr so that a default-constructed QAOutput never
  // exposes indeterminate pointers.
  // NOTE(review): the tensors are presumably owned by the interpreter and
  // overwritten by the next inference; copy the data out before running again.
  const TfLiteTensor* query_encoding = nullptr;     // not owned.
  const TfLiteTensor* response_encoding = nullptr;  // not owned.
};
}  // namespace internal
// Creates custom op resolver for USE QA task.
std::unique_ptr<tflite_shims::ops::builtin::BuiltinOpResolver>
CreateQACustomOpResolver() {
auto resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>();
resolver->AddCustom(
"TFSentencepieceTokenizeOp",
::tflite::ops::custom::Register_SENTENCEPIECE_TOKENIZER());
resolver->AddCustom(
"RaggedTensorToTensor",
::tflite::ops::custom::Register_RAGGED_TENSOR_TO_TENSOR());
return resolver;
}
// Out-of-class definition of the static constexpr data member. Before C++17
// such a definition is required whenever the member is ODR-used (for example
// bound to a const reference); from C++17 on it is redundant but harmless.
constexpr int UniversalSentenceEncoderQA::kFinalEmbeddingSize;
// Encodes the query and every response in `input`, then scores each response
// against the query with Similarity().
//
// - The query encoding is stored once, while handling the first response.
// - A response given as raw text is encoded by running the model; a response
//   that already carries an encoding keeps it unchanged.
//
// Returns InvalidArgument when the query text or the response list is empty,
// and forwards any error produced by inference, by copying the output tensors,
// or by the similarity computation (e.g. mismatched encoding sizes).
StatusOr<RetrievalOutput> UniversalSentenceEncoderQA::Retrieve(
    const RetrievalInput& input) {
  if (input.query_text().empty()) {
    return Status(StatusCode::kInvalidArgument, "query text cannot be empty.");
  }
  if (input.responses().empty()) {
    return Status(StatusCode::kInvalidArgument, "responses cannot be empty.");
  }

  RetrievalOutput output;
  const int num_responses = input.responses_size();
  for (int i = 0; i < num_responses; ++i) {
    const auto& resp = input.responses(i);
    // The query encoding does not depend on the response, so it is only stored
    // while processing the first one.
    const bool store_query_encoding = (i == 0);

    if (resp.has_raw_text()) {
      // The response is raw text: run the model on both query and response.
      // Infer() is called directly (rather than through Run()) so that an
      // inference failure is returned to the caller instead of being
      // unwrapped unchecked.
      QAInput qa_input;
      qa_input.query_text = input.query_text();
      qa_input.response_text = resp.raw_text().text();
      qa_input.response_context = resp.raw_text().context();
      ASSIGN_OR_RETURN(const QAOutput out, Infer(qa_input));

      if (store_query_encoding) {
        RETURN_IF_ERROR(
            CopyVector(out.query_encoding, output.mutable_query_encoding()));
      }
      auto* result = output.mutable_response_results()->Add();
      RETURN_IF_ERROR(
          CopyVector(out.response_encoding, result->mutable_encoding()));
    } else {
      // The response is already encoded: only the query needs the model.
      if (store_query_encoding) {
        ASSIGN_OR_RETURN(FeatureVector query_encoding,
                         EncodeQuery(input.query_text()));
        *output.mutable_query_encoding() = std::move(query_encoding);
      }
      auto* result = output.mutable_response_results()->Add();
      *result->mutable_encoding() = resp.text_encoding();
    }
  }

  // Calculate scores.
  // TODO(tianlin): For a large size of results, it is more efficient to use
  // matrix multiplication.
  for (auto& result : *output.mutable_response_results()) {
    ASSIGN_OR_RETURN(const float score,
                     Similarity(output.query_encoding(), result.encoding()));
    result.set_score(score);
  }
  return output;
}
// Runs the model on `query_text` alone (response text and context are left
// empty) and returns a copy of the query encoding.
//
// Returns InvalidArgument when `query_text` is empty, and forwards any error
// from inference or from copying the output tensor.
StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeQuery(
    absl::string_view query_text) {
  if (query_text.empty()) {
    return Status(StatusCode::kInvalidArgument, "query text cannot be empty.");
  }
  // Infer() is called directly (rather than through Run()) so that an
  // inference failure is returned instead of being unwrapped unchecked.
  QAInput qa_input;
  qa_input.query_text = std::string(query_text);
  ASSIGN_OR_RETURN(const QAOutput output, Infer(qa_input));

  FeatureVector encoding;
  RETURN_IF_ERROR(CopyVector(output.query_encoding, &encoding));
  return encoding;
}
// Runs the model on a response (its text and context, with an empty query)
// and returns a copy of the response encoding.
//
// Returns InvalidArgument when both `response_text` and `response_context`
// are empty, and forwards any error from inference or from copying the output
// tensor.
StatusOr<FeatureVector> UniversalSentenceEncoderQA::EncodeResponse(
    absl::string_view response_text,
    absl::string_view response_context) {
  if (response_text.empty() && response_context.empty()) {
    return Status(
        StatusCode::kInvalidArgument,
        "either response text or context should be set to non-empty.");
  }
  // Infer() is called directly (rather than through Run()) so that an
  // inference failure is returned instead of being unwrapped unchecked.
  QAInput qa_input;
  qa_input.response_text = std::string(response_text);
  qa_input.response_context = std::string(response_context);
  ASSIGN_OR_RETURN(const QAOutput output, Infer(qa_input));

  FeatureVector encoding;
  RETURN_IF_ERROR(CopyVector(output.response_encoding, &encoding));
  return encoding;
}
// Scores two encodings by the dot product of their float values. Returns the
// error reported by Dot() when the two vectors differ in size.
StatusOr<float> UniversalSentenceEncoderQA::Similarity(const FeatureVector& a,
                                                       const FeatureVector& b) {
  return Dot(a.value_float(), b.value_float());
}
// Returns the indices of the `k` highest-scoring response results in
// `output`, ordered by descending score.
//
// `k == 0` requests a ranking of all results; any larger `k` is clamped to the
// number of results, so the effective count lies in [0, total].
std::vector<size_t> UniversalSentenceEncoderQA::Top(
    const RetrievalOutput& output,
    size_t k) {
  const size_t total = size_t(output.response_results_size());
  const size_t top_n = (k == 0) ? total : std::min(k, total);

  // Start from the identity permutation 0, 1, ..., total - 1.
  std::vector<size_t> order;
  order.reserve(total);
  for (size_t idx = 0; idx < total; ++idx) {
    order.push_back(idx);
  }

  const auto by_descending_score = [&output](size_t lhs, size_t rhs) {
    return output.response_results(rhs).score() <
           output.response_results(lhs).score();
  };
  // Only the leading `top_n` positions need to be in sorted order.
  std::partial_sort(order.begin(), order.begin() + top_n, order.end(),
                    by_descending_score);

  // Drop the unsorted tail and hand back the ranked prefix.
  order.resize(top_n);
  return order;
}
// Writes the three text fields of `input` into the model's input tensors.
//
// When the model has input tensor metadata, tensors are located by name
// ("inp_text", "res_text", "res_context"). Otherwise they are taken by
// position: index 0 is the query text, index 1 the response context and
// index 2 the response text.
//
// Returns InvalidArgument when a required tensor cannot be located (name not
// found, or fewer than three input tensors in the positional fallback), and
// forwards any error from PopulateTensor().
Status UniversalSentenceEncoderQA::Preprocess(
    const std::vector<TfLiteTensor*>& input_tensors,
    const QAInput& input) {
  auto* input_tensor_metadatas =
      GetMetadataExtractor()->GetInputTensorMetadata();

  TfLiteTensor* query_text_tensor = nullptr;
  TfLiteTensor* response_text_tensor = nullptr;
  TfLiteTensor* response_context_tensor = nullptr;
  if (input_tensor_metadatas) {
    query_text_tensor = FindTensorByName(input_tensors, input_tensor_metadatas,
                                         kQueryTextTensorName);
    response_text_tensor = FindTensorByName(
        input_tensors, input_tensor_metadatas, kResponseTextTensorName);
    response_context_tensor = FindTensorByName(
        input_tensors, input_tensor_metadatas, kResponseContextTensorName);
  } else {
    // Positional fallback: guard the indexing below against short vectors.
    if (input_tensors.size() < 3) {
      return Status(
          StatusCode::kInvalidArgument,
          absl::StrFormat("expected at least 3 input tensors, found %d.",
                          input_tensors.size()));
    }
    query_text_tensor = input_tensors[0];
    response_context_tensor = input_tensors[1];
    response_text_tensor = input_tensors[2];
  }

  // A failed lookup yields nullptr; report it rather than dereferencing it in
  // PopulateTensor().
  struct NamedTensor {
    const char* name;
    const TfLiteTensor* tensor;
  };
  const NamedTensor required_tensors[] = {
      {kQueryTextTensorName, query_text_tensor},
      {kResponseTextTensorName, response_text_tensor},
      {kResponseContextTensorName, response_context_tensor},
  };
  for (const NamedTensor& entry : required_tensors) {
    if (entry.tensor == nullptr) {
      return Status(
          StatusCode::kInvalidArgument,
          absl::StrFormat("input tensor \"%s\" not found in the model.",
                          entry.name));
    }
  }

  RETURN_IF_ERROR(PopulateTensor(input.query_text, query_text_tensor));
  RETURN_IF_ERROR(PopulateTensor(input.response_text, response_text_tensor));
  RETURN_IF_ERROR(
      PopulateTensor(input.response_context, response_context_tensor));
  return absl::OkStatus();
}
// Collects the query and response encoding tensors from the model outputs.
//
// When the model has output tensor metadata, tensors are located by name
// ("query_encoding", "response_encoding"). Otherwise they are taken by
// position: index 0 is the query encoding, index 1 the response encoding.
// The returned QAOutput only holds non-owning pointers to those tensors.
//
// Returns InvalidArgument when a required tensor cannot be located (name not
// found, or fewer than two output tensors in the positional fallback).
StatusOr<QAOutput> UniversalSentenceEncoderQA::Postprocess(
    const std::vector<const TfLiteTensor*>& output_tensors,
    const QAInput& /*input*/) {
  auto* output_tensor_metadatas =
      GetMetadataExtractor()->GetOutputTensorMetadata();

  const TfLiteTensor* query_encoding_tensor = nullptr;
  const TfLiteTensor* response_encoding_tensor = nullptr;
  if (output_tensor_metadatas) {
    query_encoding_tensor = FindTensorByName(
        output_tensors, output_tensor_metadatas, kQueryEncodingTensorName);
    response_encoding_tensor = FindTensorByName(
        output_tensors, output_tensor_metadatas, kResponseEncodingTensorName);
  } else {
    // Positional fallback: guard the indexing below against short vectors.
    if (output_tensors.size() < 2) {
      return Status(
          StatusCode::kInvalidArgument,
          absl::StrFormat("expected at least 2 output tensors, found %d.",
                          output_tensors.size()));
    }
    query_encoding_tensor = output_tensors[0];
    response_encoding_tensor = output_tensors[1];
  }

  // A failed lookup yields nullptr; report it here so that callers never
  // receive a QAOutput with null tensors.
  if (query_encoding_tensor == nullptr) {
    return Status(StatusCode::kInvalidArgument,
                  absl::StrFormat("output tensor \"%s\" not found in the model.",
                                  kQueryEncodingTensorName));
  }
  if (response_encoding_tensor == nullptr) {
    return Status(StatusCode::kInvalidArgument,
                  absl::StrFormat("output tensor \"%s\" not found in the model.",
                                  kResponseEncodingTensorName));
  }

  QAOutput output;
  output.query_encoding = query_encoding_tensor;
  output.response_encoding = response_encoding_tensor;
  return output;
}
// Packs the three texts into a QAInput, runs inference and returns the
// resulting QAOutput.
//
// NOTE: the StatusOr returned by Infer() is unwrapped with value() without an
// ok() check, so this function has no way to report an inference failure to
// its caller.
internal::QAOutput UniversalSentenceEncoderQA::Run(
    absl::string_view query_text,
    absl::string_view response_text,
    absl::string_view response_context) {
  QAInput qa_input;
  qa_input.query_text = query_text;
  qa_input.response_text = response_text;
  qa_input.response_context = response_context;

  auto inference_result = Infer(qa_input);
  return std::move(inference_result).value();
}
// Factory: validates `options`, builds the task from its base options using
// `resolver`, and hands ownership of a private copy of the options to the
// created object.
//
// Returns the error from SanityCheckOptions() or from
// TaskAPIFactory::CreateFromBaseOptions() when either fails.
StatusOr<std::unique_ptr<UniversalSentenceEncoderQA>>
UniversalSentenceEncoderQA::CreateFromOption(
    const RetrievalOptions& options,
    std::unique_ptr<tflite::OpResolver> resolver) {
  RETURN_IF_ERROR(SanityCheckOptions(options));

  // Work on a heap copy of the options so that the ExternalFile they contain
  // lives as long as the created object, which takes ownership of the copy
  // below. Moving the unique_ptr does not move the pointee, so
  // `base_options` stays valid.
  auto owned_options = absl::make_unique<RetrievalOptions>(options);
  const auto* base_options = &owned_options->base_options();

  ASSIGN_OR_RETURN(
      auto qa_task,
      TaskAPIFactory::CreateFromBaseOptions<UniversalSentenceEncoderQA>(
          base_options, std::move(resolver)));
  qa_task->proto_options_ = std::move(owned_options);
  return std::move(qa_task);
}
} // namespace retrieval
} // namespace text
} // namespace task
} // namespace tflite