| From b96b199069e844cbb8dc428adee741fa40e76af8 Mon Sep 17 00:00:00 2001 |
| From: Xiang Xiao <xiangxiao@google.com> |
| Date: Tue, 10 Sep 2024 14:51:46 -0700 |
| Subject: [PATCH] Implement the GetDependencyHeads method for |
| DependencyParserModel. |
| |
| This method runs a TFLite dependency parser model on a string and |
| returns a vector of dependency head for each token in the string. |
| |
| Bug: 339037155 |
| Change-Id: I4a89ad43849fd3fa85b82d01d2cad915aa78e174 |
| --- |
| third_party/tensorflow-text/BUILD.gn | 2 + |
| .../shims/tensorflow/core/lib/core/errors.h | 45 +++++++++++++++++++ |
| .../shims/tensorflow/core/lib/core/status.h | 14 ++++++ |
| .../shims/tensorflow/core/platform/logging.h | 8 ++++ |
| .../tensorflow_text/core/kernels/mst_solver.h | 7 ++- |
| 5 files changed, 72 insertions(+), 4 deletions(-) |
| create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h |
| create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h |
| create mode 100644 third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h |
| |
| diff --git a/third_party/tensorflow-text/BUILD.gn b/third_party/tensorflow-text/BUILD.gn |
| index 9a8d5fef1c5dc..7c5fa1b37c3b4 100644 |
| --- a/third_party/tensorflow-text/BUILD.gn |
| +++ b/third_party/tensorflow-text/BUILD.gn |
| @@ -17,6 +17,8 @@ config("tensorflow-text-flags") { |
| |
| static_library("tensorflow-text") { |
| sources = [ |
| + "src/tensorflow_text/core/kernels/disjoint_set_forest.h", |
| + "src/tensorflow_text/core/kernels/mst_solver.h", |
| "src/tensorflow_text/core/kernels/regex_split.cc", |
| "src/tensorflow_text/core/kernels/regex_split.h", |
| "src/tensorflow_text/core/kernels/wordpiece_tokenizer.cc", |
| diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h |
| new file mode 100644 |
| index 0000000000000..0475248b26b25 |
| --- /dev/null |
| +++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/errors.h |
| @@ -0,0 +1,45 @@ |
| +// Copyright 2024 The Chromium Authors |
| +// Use of this source code is governed by a BSD-style license that can be |
| +// found in the LICENSE file. |
| + |
| +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_ |
| +#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_ |
| + |
| +#include "absl/status/status.h" |
| +#include "absl/strings/str_cat.h" |
| + |
| +#define TF_PREDICT_FALSE(x) (x) |
| +#define TF_PREDICT_TRUE(x) (x) |
| + |
| +// For propagating errors when calling a function. |
| +#define TF_RETURN_IF_ERROR(...) \ |
| + do { \ |
| + ::absl::Status _status = (__VA_ARGS__); \ |
| + if (TF_PREDICT_FALSE(!_status.ok())) { \ |
| + return _status; \ |
| + } \ |
| + } while (0) |
| + |
| +namespace tensorflow::errors { |
| + |
| +template <typename Arg1, typename Arg2> |
| +::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2) { |
| + return absl::Status(absl::StatusCode::kInvalidArgument, |
| + absl::StrCat(arg1, arg2)); |
| +} |
| + |
| +template <typename Arg1, typename Arg2, typename Arg3, typename Arg4> |
| +::absl::Status InvalidArgument(Arg1 arg1, Arg2 arg2, Arg3 arg3, Arg4 arg4) { |
| + return absl::Status(absl::StatusCode::kInvalidArgument, |
| + absl::StrCat(arg1, arg2, arg3, arg4)); |
| +} |
| + |
| +template <typename... Args> |
| +absl::Status FailedPrecondition(Args... args) { |
| + return absl::Status(absl::StatusCode::kFailedPrecondition, |
| + absl::StrCat(args...)); |
| +} |
| + |
| +} // namespace tensorflow::errors |
| + |
| +#endif // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_ERRORS_H_ |
| diff --git a/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h |
| new file mode 100644 |
| index 0000000000000..8b16ba1f64c09 |
| --- /dev/null |
| +++ b/third_party/tensorflow-text/shims/tensorflow/core/lib/core/status.h |
| @@ -0,0 +1,14 @@ |
| +// Copyright 2024 The Chromium Authors |
| +// Use of this source code is governed by a BSD-style license that can be |
| +// found in the LICENSE file. |
| + |
| +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_ |
| +#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_ |
| + |
| +#include "absl/status/status.h" |
| + |
| +namespace tensorflow { |
| +using Status = ::absl::Status; |
| +} |
| + |
| +#endif // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_LIB_CORE_STATUS_H_ |
| diff --git a/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h |
| new file mode 100644 |
| index 0000000000000..776917efba118 |
| --- /dev/null |
| +++ b/third_party/tensorflow-text/shims/tensorflow/core/platform/logging.h |
| @@ -0,0 +1,8 @@ |
| +// Copyright 2024 The Chromium Authors |
| +// Use of this source code is governed by a BSD-style license that can be |
| +// found in the LICENSE file. |
| + |
| +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_ |
| +#define THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_ |
| + |
| +#endif // THIRD_PARTY_TENSORFLOW_TEXT_SHIMS_TENSORFLOW_CORE_PLATFORM_LOGGING_H_ |
| diff --git a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h |
| index 818518bc2fdef..50372c5853d27 100644 |
| --- a/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h |
| +++ b/third_party/tensorflow-text/src/tensorflow_text/core/kernels/mst_solver.h |
| @@ -503,9 +503,9 @@ void MstSolver<Index, Score>::ContractCycle(Index node) { |
| for (const auto &node_and_arc : cycle_) { |
| // Set the |score_offset| to the cost of breaking the cycle by replacing the |
| // arc currently directed into the |cycle_node|. |
| - const Index cycle_node = node_and_arc.first; |
| + const Index current_cycle_node = node_and_arc.first; |
| const Score score_offset = -node_and_arc.second->score; |
| - MergeInboundArcs(cycle_node, score_offset, contracted_node); |
| + MergeInboundArcs(current_cycle_node, score_offset, contracted_node); |
| } |
| } |
| |
| @@ -586,7 +586,6 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase( |
| argmax[target] = arc.source - 1; |
| } |
| } |
| - DCHECK_GE(num_roots, 1); |
| |
| // Even when |forest_| is false, |num_roots| can still be more than 1. While |
| // the root score penalty discourages structures with multiple root arcs, it |
| @@ -595,7 +594,7 @@ tensorflow::Status MstSolver<Index, Score>::ExpansionPhase( |
| // produce an all-root structure in spite of the root score penalty. As this |
| // example illustrates, however, |num_roots| will be more than 1 if and only |
| // if the original digraph is infeasible for trees. |
| - if (!forest_ && num_roots != 1) { |
| + if (num_roots < 1 || (!forest_ && num_roots != 1)) { |
| return tensorflow::errors::FailedPrecondition("Infeasible digraph"); |
| } |
| |
| -- |
| 2.46.0.662.g92d0881bb0-goog |
| |